一区二区三区日韩精品-日韩经典一区二区三区-五月激情综合丁香婷婷-欧美精品中文字幕专区

分享

實戰(zhàn) 遷移學(xué)習(xí) VGG19、ResNet50、InceptionV3 實踐 貓狗大戰(zhàn) 問題

 昵稱56314485 2018-06-03

一、實踐流程

1、數(shù)據(jù)預(yù)處理

主要是對訓(xùn)練數(shù)據(jù)進行隨機偏移、轉(zhuǎn)動等變換圖像處理,這樣可以盡可能讓訓(xùn)練數(shù)據(jù)多樣化

另外處理數(shù)據(jù)方式采用分批無序讀取的形式,避免了數(shù)據(jù)按目錄排序訓(xùn)練

  1. #數(shù)據(jù)準備  
  2. def DataGen(self, dir_path, img_row, img_col, batch_size, is_train):  
  3.     if is_train:  
  4.         datagen = ImageDataGenerator(rescale=1./255,  
  5.             zoom_range=0.25, rotation_range=15.,  
  6.             channel_shift_range=25., width_shift_range=0.02, height_shift_range=0.02,  
  7.             horizontal_flip=True, fill_mode='constant')  
  8.     else:  
  9.         datagen = ImageDataGenerator(rescale=1./255)  
  10.   
  11.     generator = datagen.flow_from_directory(  
  12.         dir_path, target_size=(img_row, img_col),  
  13.         batch_size=batch_size,  
  14.         shuffle=is_train)  
  15.   
  16.     return generator  
2、載入現(xiàn)有模型

這個部分是核心工作,目的是使用ImageNet訓(xùn)練出的權(quán)重來做我們的特征提取器,注意這里后面的分類層去掉

  1. base_model = InceptionV3(weights='imagenet', include_top=False, pooling=None,  
  2.                            input_shape=(img_rows, img_cols, color),  
  3.                            classes=nb_classes)  

然后是凍結(jié)這些層,因為是訓(xùn)練好的

  1. for layer in base_model.layers:  
  2.     layer.trainable = False  
而分類部分,需要我們根據(jù)現(xiàn)有需求來新定義的,這里可以根據(jù)實際情況自己進行調(diào)整,比如這樣
  1. x = base_model.output  
  2. # 添加自己的全鏈接分類層  
  3. x = GlobalAveragePooling2D()(x)  
  4. x = Dense(1024, activation='relu')(x)  
  5. predictions = Dense(nb_classes, activation='softmax')(x)  
或者

  1. x = base_model.output  
  2.  #添加自己的全鏈接分類層  
  3.  x = Flatten()(x)  
  4.  predictions = Dense(nb_classes, activation='softmax')(x)  
3、訓(xùn)練模型

這里我們用fit_generator函數(shù),它可以避免了一次性加載大量的數(shù)據(jù),并且生成器與模型將并行執(zhí)行以提高效率。比如可以在CPU上進行實時的數(shù)據(jù)提升,同時在GPU上進行模型訓(xùn)練

  1. history_ft = model.fit_generator(  
  2. train_generator,  
  3. steps_per_epoch=steps_per_epoch,  
  4. epochs=epochs,  
  5. validation_data=validation_generator,  
  6. validation_steps=validation_steps)  

二、貓狗大戰(zhàn)數(shù)據(jù)集


訓(xùn)練數(shù)據(jù)540M,測試數(shù)據(jù)270M,大家可以去官網(wǎng)下載

https://www./c/dogs-vs-cats-redux-kernels-edition/data

下載后把數(shù)據(jù)分成dog和cat兩個目錄來存放

三、訓(xùn)練

訓(xùn)練的時候會自動去下權(quán)值,比如vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5,但是如果我們已經(jīng)下載好了的話,可以改源代碼,讓他直接讀取我們的下載好的權(quán)值,比如在resnet50.py中


1、VGG19

vgg19的深度有26層,參數(shù)達到了549M,原模型最后有3個全連接層做分類器所以我還是加了一個1024的全連接層,訓(xùn)練10輪的情況達到了89%


2、ResNet50

ResNet50的深度達到了168層,但是參數(shù)只有99M,分類模型我就簡單點,一層直接分類,訓(xùn)練10輪的達到了96%的準確率


3、inception_v3

InceptionV3的深度159層,參數(shù)92M,訓(xùn)練10輪的結(jié)果

這是一層直接分類的結(jié)果


這是加了一個512全連接的,大家可以隨意調(diào)整測試



四、完整的代碼

  1. # -*- coding: utf-8 -*-  
  2. import os  
  3. from keras.utils import plot_model  
  4. from keras.applications.resnet50 import ResNet50  
  5. from keras.applications.vgg19 import VGG19  
  6. from keras.applications.inception_v3 import InceptionV3  
  7. from keras.layers import Dense,Flatten,GlobalAveragePooling2D  
  8. from keras.models import Model,load_model  
  9. from keras.optimizers import SGD  
  10. from keras.preprocessing.image import ImageDataGenerator  
  11. import matplotlib.pyplot as plt  
  12.   
  13. class PowerTransferMode:  
  14.     #數(shù)據(jù)準備  
  15.     def DataGen(self, dir_path, img_row, img_col, batch_size, is_train):  
  16.         if is_train:  
  17.             datagen = ImageDataGenerator(rescale=1./255,  
  18.                 zoom_range=0.25, rotation_range=15.,  
  19.                 channel_shift_range=25., width_shift_range=0.02, height_shift_range=0.02,  
  20.                 horizontal_flip=True, fill_mode='constant')  
  21.         else:  
  22.             datagen = ImageDataGenerator(rescale=1./255)  
  23.   
  24.         generator = datagen.flow_from_directory(  
  25.             dir_path, target_size=(img_row, img_col),  
  26.             batch_size=batch_size,  
  27.             #class_mode='binary',  
  28.             shuffle=is_train)  
  29.   
  30.         return generator  
  31.   
  32.     #ResNet模型  
  33.     def ResNet50_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True, is_plot_model=False):  
  34.         color = 3 if RGB else 1  
  35.         base_model = ResNet50(weights='imagenet', include_top=False, pooling=None, input_shape=(img_rows, img_cols, color),  
  36.                               classes=nb_classes)  
  37.   
  38.         #凍結(jié)base_model所有層,這樣就可以正確獲得bottleneck特征  
  39.         for layer in base_model.layers:  
  40.             layer.trainable = False  
  41.   
  42.         x = base_model.output  
  43.         #添加自己的全鏈接分類層  
  44.         x = Flatten()(x)  
  45.         #x = GlobalAveragePooling2D()(x)  
  46.         #x = Dense(1024, activation='relu')(x)  
  47.         predictions = Dense(nb_classes, activation='softmax')(x)  
  48.   
  49.         #訓(xùn)練模型  
  50.         model = Model(inputs=base_model.input, outputs=predictions)  
  51.         sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)  
  52.         model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])  
  53.   
  54.         #繪制模型  
  55.         if is_plot_model:  
  56.             plot_model(model, to_file='resnet50_model.png',show_shapes=True)  
  57.   
  58.         return model  
  59.   
  60.   
  61.     #VGG模型  
  62.     def VGG19_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True, is_plot_model=False):  
  63.         color = 3 if RGB else 1  
  64.         base_model = VGG19(weights='imagenet', include_top=False, pooling=None, input_shape=(img_rows, img_cols, color),  
  65.                               classes=nb_classes)  
  66.   
  67.         #凍結(jié)base_model所有層,這樣就可以正確獲得bottleneck特征  
  68.         for layer in base_model.layers:  
  69.             layer.trainable = False  
  70.   
  71.         x = base_model.output  
  72.         #添加自己的全鏈接分類層  
  73.         x = GlobalAveragePooling2D()(x)  
  74.         x = Dense(1024, activation='relu')(x)  
  75.         predictions = Dense(nb_classes, activation='softmax')(x)  
  76.   
  77.         #訓(xùn)練模型  
  78.         model = Model(inputs=base_model.input, outputs=predictions)  
  79.         sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)  
  80.         model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])  
  81.   
  82.         # 繪圖  
  83.         if is_plot_model:  
  84.             plot_model(model, to_file='vgg19_model.png',show_shapes=True)  
  85.   
  86.         return model  
  87.   
  88.     # InceptionV3模型  
  89.     def InceptionV3_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True,  
  90.                     is_plot_model=False):  
  91.         color = 3 if RGB else 1  
  92.         base_model = InceptionV3(weights='imagenet', include_top=False, pooling=None,  
  93.                            input_shape=(img_rows, img_cols, color),  
  94.                            classes=nb_classes)  
  95.   
  96.         # 凍結(jié)base_model所有層,這樣就可以正確獲得bottleneck特征  
  97.         for layer in base_model.layers:  
  98.             layer.trainable = False  
  99.   
  100.         x = base_model.output  
  101.         # 添加自己的全鏈接分類層  
  102.         x = GlobalAveragePooling2D()(x)  
  103.         x = Dense(1024, activation='relu')(x)  
  104.         predictions = Dense(nb_classes, activation='softmax')(x)  
  105.   
  106.         # 訓(xùn)練模型  
  107.         model = Model(inputs=base_model.input, outputs=predictions)  
  108.         sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)  
  109.         model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])  
  110.   
  111.         # 繪圖  
  112.         if is_plot_model:  
  113.             plot_model(model, to_file='inception_v3_model.png', show_shapes=True)  
  114.   
  115.         return model  
  116.   
  117.     #訓(xùn)練模型  
  118.     def train_model(self, model, epochs, train_generator, steps_per_epoch, validation_generator, validation_steps, model_url, is_load_model=False):  
  119.         # 載入模型  
  120.         if is_load_model and os.path.exists(model_url):  
  121.             model = load_model(model_url)  
  122.   
  123.         history_ft = model.fit_generator(  
  124.             train_generator,  
  125.             steps_per_epoch=steps_per_epoch,  
  126.             epochs=epochs,  
  127.             validation_data=validation_generator,  
  128.             validation_steps=validation_steps)  
  129.         # 模型保存  
  130.         model.save(model_url,overwrite=True)  
  131.         return history_ft  
  132.   
  133.     # 畫圖  
  134.     def plot_training(self, history):  
  135.       acc = history.history['acc']  
  136.       val_acc = history.history['val_acc']  
  137.       loss = history.history['loss']  
  138.       val_loss = history.history['val_loss']  
  139.       epochs = range(len(acc))  
  140.       plt.plot(epochs, acc, 'b-')  
  141.       plt.plot(epochs, val_acc, 'r')  
  142.       plt.title('Training and validation accuracy')  
  143.       plt.figure()  
  144.       plt.plot(epochs, loss, 'b-')  
  145.       plt.plot(epochs, val_loss, 'r-')  
  146.       plt.title('Training and validation loss')  
  147.       plt.show()  
  148.   
  149.   
  150. if __name__ == '__main__':  
  151.     image_size = 197  
  152.     batch_size = 32  
  153.   
  154.     transfer = PowerTransferMode()  
  155.   
  156.     #得到數(shù)據(jù)  
  157.     train_generator = transfer.DataGen('data/cat_dog_Dataset/train', image_size, image_size, batch_size, True)  
  158.     validation_generator = transfer.DataGen('data/cat_dog_Dataset/test', image_size, image_size, batch_size, False)  
  159.   
  160.     #VGG19  
  161.     #model = transfer.VGG19_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=False)  
  162.     #history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'vgg19_model_weights.h5', is_load_model=False)  
  163.   
  164.     #ResNet50  
  165.     model = transfer.ResNet50_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=False)  
  166.     history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60'resnet50_model_weights.h5', is_load_model=False)  
  167.   
  168.     #InceptionV3  
  169.     #model = transfer.InceptionV3_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=True)  
  170.     #history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'inception_v3_model_weights.h5', is_load_model=False)  
  171.   
  172.     # 訓(xùn)練的acc_loss圖  
  173.     transfer.plot_training(history_ft)  



    本站是提供個人知識管理的網(wǎng)絡(luò)存儲空間,所有內(nèi)容均由用戶發(fā)布,不代表本站觀點。請注意甄別內(nèi)容中的聯(lián)系方式、誘導(dǎo)購買等信息,謹防詐騙。如發(fā)現(xiàn)有害或侵權(quán)內(nèi)容,請點擊一鍵舉報。
    轉(zhuǎn)藏 分享 獻花(0

    0條評論

    發(fā)表

    請遵守用戶 評論公約

    類似文章 更多

    国产爆操白丝美女在线观看| 女厕偷窥一区二区三区在线| 狠狠干狠狠操亚洲综合| 日韩欧美91在线视频| 四十女人口红哪个色好看| 视频在线观看色一区二区| 亚洲最新中文字幕在线视频| 久久精品伊人一区二区| 国产精品丝袜美腿一区二区| 日韩成人h视频在线观看 | 精品人妻一区二区三区四区久久 | 国内女人精品一区二区三区| 国产精品人妻熟女毛片av久| 99久久精品视频一区二区| 欧美自拍系列精品在线| 日韩欧美一区二区黄色| 国产精品久久久久久久久久久痴汉| 欧美特色特黄一级大黄片| 国产91色综合久久高清| 美日韩一区二区精品系列| 欧美日本精品视频在线观看| 超薄肉色丝袜脚一区二区| 午夜福利大片亚洲一区| 日本一级特黄大片国产| 欧美日不卡无在线一区| 91欧美亚洲精品在线观看| 亚洲视频一区二区久久久| 99久久成人精品国产免费| 丝袜视频日本成人午夜视频| 中文字幕高清免费日韩视频| 亚洲国产黄色精品在线观看| 国产精品久久女同磨豆腐| 99精品人妻少妇一区二区人人妻| 亚洲国产综合久久天堂| 在线亚洲成人中文字幕高清| 大尺度激情福利视频在线观看| 高清亚洲精品中文字幕乱码| 97人妻精品一区二区三区免| 亚洲精品中文字幕一二三| 欧美日韩精品一区免费 | 日韩欧美国产亚洲一区|