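I. Workflow
1. Data preprocessing
The training images are augmented with random shifts, rotations and other transforms so that the training data is as varied as possible.
The data is also read in shuffled batches, so training does not follow the images in directory order.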
```python
from keras.preprocessing.image import ImageDataGenerator

# Data preparation
def DataGen(self, dir_path, img_row, img_col, batch_size, is_train):
    if is_train:
        # Training images: random zoom, rotation, channel shift, small shifts and horizontal flips
        datagen = ImageDataGenerator(rescale=1./255,
                                     zoom_range=0.25, rotation_range=15.,
                                     channel_shift_range=25., width_shift_range=0.02, height_shift_range=0.02,
                                     horizontal_flip=True, fill_mode='constant')
    else:
        # Validation images are only rescaled
        datagen = ImageDataGenerator(rescale=1./255)

    # Labels come from the class subdirectories; only the training data is shuffled
    generator = datagen.flow_from_directory(
        dir_path, target_size=(img_row, img_col),
        batch_size=batch_size,
        shuffle=is_train)

    return generator
```
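To make the idea of reading "shuffled batches" concrete, here is a minimal pure-Python sketch of what flow_from_directory does conceptually: map each class subdirectory to an integer label (in alphabetical order, as Keras does for class_indices), shuffle the file list, and yield fixed-size batches of (path, label) pairs. It is not Keras's implementation: the name `shuffled_batches`, the reshuffle on every pass and the lack of file-extension filtering are assumptions made for illustration.

```python
import os
import random


def shuffled_batches(dir_path, batch_size, shuffle=True, seed=None):
    """Yield lists of (file_path, label) forever, one class per subdirectory.

    Hypothetical helper illustrating the idea behind flow_from_directory.
    Labels follow the alphabetical order of the subdirectory names.
    """
    class_names = sorted(d for d in os.listdir(dir_path)
                         if os.path.isdir(os.path.join(dir_path, d)))
    samples = []
    for label, name in enumerate(class_names):
        class_dir = os.path.join(dir_path, name)
        for fname in sorted(os.listdir(class_dir)):
            samples.append((os.path.join(class_dir, fname), label))
    if not samples:
        raise ValueError('no files found under %s' % dir_path)

    rng = random.Random(seed)
    while True:
        order = list(samples)
        if shuffle:
            # Reshuffle on every pass so batches never follow directory order
            rng.shuffle(order)
        for start in range(0, len(order), batch_size):
            yield order[start:start + batch_size]
```

With cat/ and dog/ subdirectories, cats get label 0 and dogs label 1; the last batch of a pass may be smaller than batch_size.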
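2. Loading a pretrained model
This is the core step: the weights learned on ImageNet serve as our feature extractor. Note that the original classification layers at the top of the network are dropped here (include_top=False).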
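```python
from keras.applications.inception_v3 import InceptionV3

img_rows, img_cols, color, nb_classes = 197, 197, 3, 2  # example values, as in the full code
base_model = InceptionV3(weights='imagenet', include_top=False, pooling=None,
                         input_shape=(img_rows, img_cols, color),
                         classes=nb_classes)  # classes is ignored when include_top=False
```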
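Next, freeze these layers, since they are already trained:
```python
# Freeze the pretrained layers so their ImageNet weights are not updated
for layer in base_model.layers:
    layer.trainable = False
```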
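Freezing every layer is the pure feature-extraction setting used in this article. A common follow-up, not part of the experiments here, is to unfreeze only the top few layers once the new head has been trained. The sketch below assumes only that each layer exposes a boolean `trainable` attribute, as Keras layers do; the helper name `unfreeze_top` and the value of `n` are hypothetical. In Keras the model must be compiled again after `trainable` changes for the change to take effect.

```python
from types import SimpleNamespace


def unfreeze_top(layers, n):
    """Freeze all layers except the last n (hypothetical helper).

    Works on any sequence of objects with a boolean `trainable`
    attribute, e.g. base_model.layers in Keras. n=0 freezes everything.
    Returns the number of trainable layers.
    """
    if n < 0:
        raise ValueError('n must be non-negative')
    cutoff = len(layers) - n
    for i, layer in enumerate(layers):
        layer.trainable = i >= cutoff
    return sum(1 for layer in layers if layer.trainable)


if __name__ == '__main__':
    # Stand-in layers for demonstration; in practice pass base_model.layers
    layers = [SimpleNamespace(name='layer%d' % i, trainable=False) for i in range(5)]
    print(unfreeze_top(layers, 2))        # 2
    print([l.trainable for l in layers])  # [False, False, False, True, True]
```

After unfreezing, one would call model.compile(...) again, typically with a smaller learning rate so the pretrained weights change only gently.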
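The classification head, on the other hand, has to be defined anew for the task at hand and can be adjusted as needed, for example:
```python
from keras.layers import Dense, GlobalAveragePooling2D

x = base_model.output
# Add our own fully connected classification layers
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(nb_classes, activation='softmax')(x)
```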
Or:
```python
from keras.layers import Dense, Flatten

x = base_model.output
# Add our own fully connected classification layer
x = Flatten()(x)
predictions = Dense(nb_classes, activation='softmax')(x)
```
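The two heads differ mainly in how many trainable weights they add. A Dense layer with `n_in` inputs and `n_out` units has `n_in * n_out + n_out` parameters; Flatten passes on all h×w×c values of the feature map, while GlobalAveragePooling2D passes on only c. The sketch below does that arithmetic. The 6×6×512 shape is our own calculation for VGG19's last convolutional block at a 197×197 input (five 2×2 max-poolings: 197→98→49→24→12→6) and is used only as an example; the function names are hypothetical.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected (Dense) layer."""
    return n_in * n_out + n_out


def head_params(feature_shape, nb_classes, pooling='flatten', hidden=None):
    """Trainable parameters of a classification head on a frozen base.

    feature_shape: (h, w, c) of the base model output.
    pooling: 'flatten' keeps h*w*c features, 'gap' keeps c.
    hidden: optional size of an extra Dense layer before the output.
    """
    h, w, c = feature_shape
    if pooling == 'flatten':
        n = h * w * c
    elif pooling == 'gap':
        n = c
    else:
        raise ValueError("pooling must be 'flatten' or 'gap'")
    total = 0
    if hidden:
        total += dense_params(n, hidden)
        n = hidden
    return total + dense_params(n, nb_classes)


if __name__ == '__main__':
    shape = (6, 6, 512)  # example feature map, see the text above
    print(head_params(shape, 2, 'flatten'))           # 36866
    print(head_params(shape, 2, 'gap', hidden=1024))  # 527362
    print(head_params(shape, 2, 'gap'))               # 1026
```

GAP followed directly by the output layer is always the lightest option; whether Flatten or GAP plus a hidden layer is larger depends on the spatial size h×w of the feature map.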
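3. Training the model
Here we use fit_generator, which avoids loading a large amount of data into memory at once. The generator runs in parallel with the model for efficiency: for example, real-time data augmentation can run on the CPU while the model trains on the GPU. (In recent TensorFlow/Keras versions fit_generator is deprecated, and model.fit accepts generators directly.)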
```python
history_ft = model.fit_generator(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=validation_steps)
```
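steps_per_epoch and validation_steps count batches, not images. To see each image roughly once per epoch, a common choice is the number of images divided by the batch size, rounded up. A small sketch; the 20,000/5,000 split is a hypothetical example, not the split used in this article (the full script hard-codes 600 and 60 steps with a batch size of 32, i.e. 19,200 training and 1,920 validation images per epoch).

```python
import math


def steps_for(num_images, batch_size):
    """Number of batches needed to see every image once (the last batch may be partial)."""
    if batch_size <= 0:
        raise ValueError('batch_size must be positive')
    return math.ceil(num_images / batch_size)


if __name__ == '__main__':
    batch_size = 32
    # Hypothetical split of the labelled images into training and validation sets
    print(steps_for(20000, batch_size))  # 625
    print(steps_for(5000, batch_size))   # 157
    # Images actually seen per epoch with the hard-coded values of the full script
    print(600 * batch_size, 60 * batch_size)  # 19200 1920
```

With a flow_from_directory generator the image count is available as `generator.samples`, so `steps_for(train_generator.samples, batch_size)` avoids hard-coding the step counts.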
II. The Dogs vs. Cats dataset
The training data is about 540 MB and the test data about 270 MB; both can be downloaded from the official site:
https://www./c/dogs-vs-cats-redux-kernels-edition/data
下載后把數(shù)據(jù)分成dog和cat兩個目錄來存放
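III. Training
On the first run, Keras downloads the pretrained weights automatically, e.g. vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5. If you have already downloaded them, you can modify the source code so that it reads your local copy directly, for example in resnet50.py.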
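Editing the library source is not the only option. keras.applications fetches its weights through keras.utils.get_file, which first looks in a local cache directory (by default ~/.keras/models for these weights) and skips the download if a valid copy is already there. A sketch that copies an already-downloaded file into that cache; the default path is an assumption about a standard Keras setup (it can differ, e.g. when the KERAS_HOME environment variable is set or the home directory is not writable), and the helper name is hypothetical. Keras checks the file hash for these weights, so the copy must be the unmodified original file.

```python
import os
import shutil


def install_cached_weights(local_path, cache_dir=None):
    """Copy a downloaded weights file into the Keras model cache (hypothetical helper).

    cache_dir defaults to ~/.keras/models, assumed here to be where
    keras.applications looks for weights on a standard setup.
    """
    if cache_dir is None:
        cache_dir = os.path.join(os.path.expanduser('~'), '.keras', 'models')
    os.makedirs(cache_dir, exist_ok=True)
    target = os.path.join(cache_dir, os.path.basename(local_path))
    if not os.path.exists(target):
        # Keep the original file name: Keras looks the file up by that exact name
        shutil.copyfile(local_path, target)
    return target
```

After `install_cached_weights('downloads/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5')` (path hypothetical), `VGG19(weights='imagenet', include_top=False, ...)` should pick up the cached file instead of downloading it.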
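1. VGG19
VGG19 has a depth of 26 layers and a 549 MB weights file (about 143.7 million parameters). Because the original model uses three fully connected layers as its classifier, I likewise added a 1024-unit fully connected layer before the output layer. After 10 epochs it reached 89% accuracy.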
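2. ResNet50
ResNet50 is 168 layers deep, yet its weights take only 99 MB (about 25.6 million parameters). Here I kept the classifier simple, with a single classification layer directly on top; after 10 epochs it reached 96% accuracy.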
3. InceptionV3
InceptionV3 is 159 layers deep with 92 MB of weights (about 23.9 million parameters). Results after 10 epochs:
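Figure: result with a single classification layer directly on top.
Figure: result with an extra 512-unit fully connected layer (the full code below uses 1024 units). Feel free to adjust these settings and experiment.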
IV. Full code
```python
# -*- coding: utf-8 -*-
import os
from keras.utils import plot_model
from keras.applications.resnet50 import ResNet50
from keras.applications.vgg19 import VGG19
from keras.applications.inception_v3 import InceptionV3
from keras.layers import Dense, Flatten, GlobalAveragePooling2D
from keras.models import Model, load_model
from keras.optimizers import SGD
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt


class PowerTransferMode:
    # Data preparation
    def DataGen(self, dir_path, img_row, img_col, batch_size, is_train):
        if is_train:
            # Training images: random zoom, rotation, channel shift, small shifts and horizontal flips
            datagen = ImageDataGenerator(rescale=1./255,
                                         zoom_range=0.25, rotation_range=15.,
                                         channel_shift_range=25., width_shift_range=0.02, height_shift_range=0.02,
                                         horizontal_flip=True, fill_mode='constant')
        else:
            # Validation images are only rescaled
            datagen = ImageDataGenerator(rescale=1./255)

        # Labels come from the class subdirectories; only the training data is shuffled
        generator = datagen.flow_from_directory(
            dir_path, target_size=(img_row, img_col),
            batch_size=batch_size,
            #class_mode='binary',
            shuffle=is_train)

        return generator

    # ResNet50 model
    # Note: ImageNet weights expect 3 input channels, so RGB=False only works with weights=None
    def ResNet50_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True, is_plot_model=False):
        color = 3 if RGB else 1
        base_model = ResNet50(weights='imagenet', include_top=False, pooling=None, input_shape=(img_rows, img_cols, color),
                              classes=nb_classes)

        # Freeze all layers of base_model so it yields the bottleneck features unchanged
        for layer in base_model.layers:
            layer.trainable = False

        x = base_model.output
        # Add our own fully connected classification layer
        x = Flatten()(x)
        #x = GlobalAveragePooling2D()(x)
        #x = Dense(1024, activation='relu')(x)
        predictions = Dense(nb_classes, activation='softmax')(x)

        # Build and compile the model
        model = Model(inputs=base_model.input, outputs=predictions)
        sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)
        model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

        # Plot the model
        if is_plot_model:
            plot_model(model, to_file='resnet50_model.png', show_shapes=True)

        return model

    # VGG19 model
    def VGG19_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True, is_plot_model=False):
        color = 3 if RGB else 1
        base_model = VGG19(weights='imagenet', include_top=False, pooling=None, input_shape=(img_rows, img_cols, color),
                           classes=nb_classes)

        # Freeze all layers of base_model so it yields the bottleneck features unchanged
        for layer in base_model.layers:
            layer.trainable = False

        x = base_model.output
        # Add our own fully connected classification layers
        x = GlobalAveragePooling2D()(x)
        x = Dense(1024, activation='relu')(x)
        predictions = Dense(nb_classes, activation='softmax')(x)

        # Build and compile the model
        model = Model(inputs=base_model.input, outputs=predictions)
        sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)
        model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

        # Plot the model
        if is_plot_model:
            plot_model(model, to_file='vgg19_model.png', show_shapes=True)

        return model

    # InceptionV3 model
    def InceptionV3_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True,
                          is_plot_model=False):
        color = 3 if RGB else 1
        base_model = InceptionV3(weights='imagenet', include_top=False, pooling=None,
                                 input_shape=(img_rows, img_cols, color),
                                 classes=nb_classes)

        # Freeze all layers of base_model so it yields the bottleneck features unchanged
        for layer in base_model.layers:
            layer.trainable = False

        x = base_model.output
        # Add our own fully connected classification layers
        x = GlobalAveragePooling2D()(x)
        x = Dense(1024, activation='relu')(x)
        predictions = Dense(nb_classes, activation='softmax')(x)

        # Build and compile the model
        model = Model(inputs=base_model.input, outputs=predictions)
        sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)
        model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

        # Plot the model
        if is_plot_model:
            plot_model(model, to_file='inception_v3_model.png', show_shapes=True)

        return model

    # Train the model
    def train_model(self, model, epochs, train_generator, steps_per_epoch, validation_generator, validation_steps, model_url, is_load_model=False):
        # Optionally resume from a previously saved model
        if is_load_model and os.path.exists(model_url):
            model = load_model(model_url)

        history_ft = model.fit_generator(
            train_generator,
            steps_per_epoch=steps_per_epoch,
            epochs=epochs,
            validation_data=validation_generator,
            validation_steps=validation_steps)
        # Save the model
        model.save(model_url, overwrite=True)
        return history_ft

    # Plot the training curves
    def plot_training(self, history):
        # Keras < 2.3 logs 'acc'/'val_acc'; newer versions log 'accuracy'/'val_accuracy'
        acc = history.history.get('acc', history.history.get('accuracy'))
        val_acc = history.history.get('val_acc', history.history.get('val_accuracy'))
        loss = history.history['loss']
        val_loss = history.history['val_loss']
        epochs = range(len(acc))
        plt.plot(epochs, acc, 'b-')
        plt.plot(epochs, val_acc, 'r')
        plt.title('Training and validation accuracy')
        plt.figure()
        plt.plot(epochs, loss, 'b-')
        plt.plot(epochs, val_loss, 'r-')
        plt.title('Training and validation loss')
        plt.show()


if __name__ == '__main__':
    image_size = 197
    batch_size = 32

    transfer = PowerTransferMode()

    # Load the data
    train_generator = transfer.DataGen('data/cat_dog_Dataset/train', image_size, image_size, batch_size, True)
    validation_generator = transfer.DataGen('data/cat_dog_Dataset/test', image_size, image_size, batch_size, False)

    # VGG19
    #model = transfer.VGG19_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=False)
    #history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'vgg19_model_weights.h5', is_load_model=False)

    # ResNet50
    model = transfer.ResNet50_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=False)
    history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'resnet50_model_weights.h5', is_load_model=False)

    # InceptionV3
    #model = transfer.InceptionV3_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=True)
    #history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'inception_v3_model_weights.h5', is_load_model=False)

    # Accuracy and loss curves from training
    transfer.plot_training(history_ft)
```