大家好,歡迎來(lái)到專欄《百戰(zhàn)GAN》,我們?cè)诠娞?hào)已經(jīng)輸出了非常多的GAN相關(guān)的理論,這一次我們開(kāi)設(shè)《百戰(zhàn)GAN》專欄,在這個(gè)專欄里,我們會(huì)進(jìn)行算法的核心思想講解,代碼的詳解,模型的訓(xùn)練等內(nèi)容。 作者&編輯 | 言有三 本文資源與生成結(jié)果展示 本文篇幅:5000字 背景要求:會(huì)使用Python,Tensorflow或者Pytorch 附帶資料:項(xiàng)目推薦,版本包括Pytorch+Tensorflow 同步平臺(tái):有三AI知識(shí)星球(一周內(nèi)) 1 項(xiàng)目背景 生成對(duì)抗網(wǎng)絡(luò)如今在計(jì)算機(jī)視覺(jué)的很多領(lǐng)域中都被廣泛應(yīng)用,需要每一個(gè)學(xué)習(xí)深度學(xué)習(xí)相關(guān)技術(shù)的算法人員掌握,我們公眾號(hào)和知識(shí)星球講述了非常多的理論知識(shí),在這個(gè)《百戰(zhàn)GAN》專欄中,我們會(huì)配合各類實(shí)戰(zhàn)案例來(lái)幫助大家進(jìn)行提升,本次項(xiàng)目開(kāi)發(fā)需要以下環(huán)境: (1) Linux系統(tǒng)或者windows系統(tǒng),使用Linux效率更高。 (2) 安裝好的Tensorflow,CPU或者GPU訓(xùn)練都可以。 2 原理簡(jiǎn)介 今天我們要實(shí)踐的模型是DCGAN和CGAN,DCGAN是第一個(gè)全卷積GAN,麻雀雖小,五臟俱全,最適合新人實(shí)踐。 DCGAN的生成器和判別器都采用了4層的網(wǎng)絡(luò)結(jié)構(gòu)。生成器網(wǎng)絡(luò)結(jié)構(gòu)如上圖所示,輸入為1×100的向量,然后經(jīng)過(guò)一個(gè)全連接層學(xué)習(xí),reshape為4×4×1024的張量,再經(jīng)過(guò)4個(gè)上采樣的反卷積網(wǎng)絡(luò)層,生成64×64的圖,各層的配置如下: 判別器輸入64×64大小的圖,經(jīng)過(guò)4次卷積,分辨率降低為4×4的大小,每一個(gè)卷積層的配置如下: DCGAN并不能控制生成圖片的類別,條件GAN(CGAN)則使用了條件控制變量作為輸入,是幾乎后續(xù)所有性能強(qiáng)大的GAN的基礎(chǔ)。網(wǎng)絡(luò)結(jié)構(gòu)如下,其中的y就是條件變量。 對(duì)于生成器來(lái)說(shuō),輸入包括z和y,兩者會(huì)進(jìn)行拼接后作為輸入。對(duì)于判別器來(lái)說(shuō),輸入包括了x和y,兩者會(huì)進(jìn)行拼接后作為輸入,當(dāng)然為了和z以及x進(jìn)行拼接,y需要做一些維度變換,即reshape操作。 關(guān)于它們的理論更加詳細(xì)的講解,大家可以移步有三AI知識(shí)星球,或者自行閱讀論文。 3 模型訓(xùn)練 接下來(lái)我們進(jìn)行實(shí)踐,選擇tensorflow框架,下面詳解具體的工程代碼,主要包括: (1) 生成器和判別器模型的定義。 (2) 損失和優(yōu)化目標(biāo)的定義。 3.1 DCGAN類定義 首先我們需要定義一個(gè)類,設(shè)計(jì)好輸入輸出,__init__函數(shù)如下: # 模型定義 class DCGAN(object): def __init__(self, sess, input_height=108, input_width=108, crop=True, batch_size=64, sample_num = 64, output_height=64, output_width=64, y_dim=None, z_dim=100, gf_dim=64, df_dim=64, gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default', max_to_keep=1, input_fname_pattern='*.jpg', checkpoint_dir='ckpts', sample_dir='samples', out_dir='./out', data_dir='./data'): 其中參數(shù)解釋如下:sess表示TensorFlow session,batch_size即批處理大??;z_dim是噪聲的維度,默認(rèn)為100;y_dim是一個(gè)可選的條件變量,比如分類標(biāo)簽,用于CGAN;gf_dim是生成器第一個(gè)卷積層的通道數(shù);df_dim是判別器第一個(gè)卷積層的通道數(shù);gfc_dim是生成器全連接層維度;dfc_dim是判別器全連接層維度;c_dim是輸入圖像維度,灰度圖為1,彩色圖為3。 從上述代碼可以看出,初始化函數(shù)__init__中配置了訓(xùn)練輸入圖尺寸,批處理大小,輸出圖尺寸,生成器的輸入維度,以及生成器和判別的卷積層和全連接層的若干維度變量。 |
|