作者丨薛潔婷 學(xué)校丨北京交通大學(xué)碩士生 研究方向丨圖像翻譯 研究動(dòng)機(jī)自 2014 年 Goodfellow 等人提出生成式對(duì)抗網(wǎng)絡(luò) (Generative Adversarial Networks, GAN) 以來,關(guān)于 GAN 穩(wěn)定訓(xùn)練的研究層出不窮,其中較為突出的是 2017 年提出的 Wasserstein GAN [1] 以及 2018 年提出的 SN-GAN [2]。其共同動(dòng)機(jī)都是通過使鑒別器滿足利普希茨(Lipschitz)限制條件(也就是讓鑒別器更加魯棒),從而提高模型的收斂速度以及穩(wěn)定性。 對(duì)抗訓(xùn)練 (Adversarial training) 作為提高模型魯棒性的經(jīng)典算法,被作者有效地結(jié)合至 GAN 的訓(xùn)練過程中并將結(jié)合后的模型命名為 Rob-GAN。實(shí)驗(yàn)表明 Rob-GAN 不僅能促使 GAN 的訓(xùn)練更加穩(wěn)定、生成結(jié)果更加逼真而且還縮減了對(duì)抗訓(xùn)練在訓(xùn)練集和測(cè)試集上的性能差距。另外,作者還從理論上分析了這一結(jié)果的本質(zhì)原因。 模型架構(gòu)首先我們來思考第一個(gè)問題:為什么 GAN 能夠改善對(duì)抗訓(xùn)練在測(cè)試集和訓(xùn)練集上的性能差距?在回答這個(gè)問題之前我們先來簡(jiǎn)單看一下對(duì)抗訓(xùn)練的過程。 對(duì)抗訓(xùn)練分為攻擊者和防御者,攻擊者是指通過對(duì)輸入樣本添加一些小的擾動(dòng)來“欺騙”分類器,讓其輸出錯(cuò)誤的分類結(jié)果。論文中作者采用了 PGD [3] 攻擊算法來產(chǎn)生對(duì)抗樣本,損失函數(shù)如 (1) 所示。其中 x+δ 表示對(duì)輸入樣本 x 添加一些小的擾動(dòng),f(x+δ,w) 是收到擾動(dòng)后分類器的輸出結(jié)果。 對(duì)于攻擊者而言希望受到擾動(dòng)后的分類器能輸出盡可能錯(cuò)的分類結(jié)果,也就是和真正的分類結(jié)果的損失要盡可能大。當(dāng)然,有攻擊者就肯定會(huì)有防御者,與攻擊相比,防御是一項(xiàng)更艱巨的任務(wù),特別是對(duì)于結(jié)合復(fù)雜模型的高維數(shù)據(jù)。防御者的損失函數(shù)如 (2) 所示。 目前對(duì)抗訓(xùn)練在小訓(xùn)練集(如 MNIST, CIFAR10)上可以訓(xùn)練出魯棒性強(qiáng)的分類器,然而一旦擴(kuò)展在大訓(xùn)練集(如 IMAGENET)上,分類器的效果將非常差,并且對(duì)抗訓(xùn)練的性能在訓(xùn)練集和測(cè)試集上的差距也很突出(如圖 1 所示),究其根本其實(shí)就是模型在測(cè)試集和訓(xùn)練集的魯棒性差異較大。 ▲ 圖1. 在不同水平攻擊下的準(zhǔn)確率 從理論上分析可知,如果在真實(shí)數(shù)據(jù)分布下模型的局部 LLV (local Lipschitz value) 越小,則模型的魯棒性越強(qiáng)。這一理論可以被描述為復(fù)合損失最小化問題(公式 3)。 但是在實(shí)際中我們并不能獲取真實(shí)數(shù)據(jù)分布 Pdata,因此一般采用先驗(yàn)分布來替換公式 3。實(shí)際上,如果我們的數(shù)據(jù)量足夠大并且假設(shè)集也設(shè)計(jì)的很合理,公式 4 最終會(huì)收斂于公式 3。 那么訓(xùn)練集中的約束的 LLV 會(huì)自動(dòng)泛化到測(cè)試集上嗎?很遺憾,答案是否定的。也就是說盡管我們能在訓(xùn)練集上有效的降低 LLV,但是對(duì)于測(cè)試集來說,這樣是無效的(如圖 2)。 ▲ 圖2. 測(cè)試集和訓(xùn)練集的局部Lipschitz值 (LLV) 比較 但是如果我們換個(gè)思路直接從真實(shí)數(shù)據(jù) Pdata 中采樣,那這個(gè)問題不就解決了嗎?看到這里你肯定很好奇,之前不是說 Pdata 無法獲取嗎?沒錯(cuò)!雖然我們沒法直接獲取其分布,但是 GAN 可以學(xué)?。?strong>也就是說我們先讓GAN去學(xué)習(xí) Pdata,然后對(duì)所學(xué)分布再進(jìn)行對(duì)抗訓(xùn)練。加入GAN后的損失函數(shù)如 5 所示。至此,我們解決了第一個(gè)問題。 接下來第二個(gè)問題是為什么加入對(duì)抗訓(xùn)練后可以促使 GAN 的訓(xùn)練更加穩(wěn)定?首先我們知道對(duì)抗樣本能夠很容易“欺騙”分類器,對(duì)于 CGAN 來說,生成器完全有可能模仿對(duì)抗樣本去“欺騙”鑒別器,就算是鑒別器能識(shí)別出一種模式的對(duì)抗樣本,但生成器很容易就能夠找到其他模式的對(duì)抗樣本,這樣的話最小最大化的游戲?qū)⒂肋h(yuǎn)不會(huì)停止,也就是生成器和鑒別器永遠(yuǎn)沒辦法達(dá)到納什均衡。 因此作者假設(shè),提高鑒別器的魯棒性對(duì)于穩(wěn)定 GAN 的訓(xùn)練至關(guān)重要。下面我們從理論上分析一下這一假設(shè)的成立的原因。 ▲ 圖3. 鑒別器的魯棒性 在 GAN 的訓(xùn)練中,生成器就類似于對(duì)抗訓(xùn)練中的“攻擊者”。如果鑒別器具有很小的 LLV (即很小),此時(shí),也就是說當(dāng)鑒別器受到攻擊時(shí),除非是擾動(dòng) δ 非常大,其并不會(huì)誤分類,如圖 3 所示。 假設(shè)在 t 時(shí)刻時(shí)鑒別器正確分類圖像為假圖即,在 t+1 時(shí)生成器如何才能使鑒別器誤分類呢?作者通過對(duì) D(x) 和 G(z;w) 進(jìn)行 Lipschitz 連續(xù)性假設(shè),可以得到一個(gè)下界: 我們發(fā)現(xiàn) LDLG 和成反比,也就是說如果鑒別器不魯棒的話即 LD 很大,那么只能讓生成器的參數(shù) w 移動(dòng)的非常小,才能保證其下界成立,此時(shí)模型就會(huì)收斂的很慢。因此,我們從理論上證明了鑒別器的魯棒性是影響 GAN 收斂速度的關(guān)鍵因素。 回顧 GAN 的發(fā)展歷史,無論是 WGAN 還是 SN-GAN 都要求鑒別器滿足全局 Lipschitz 條件限制,這無疑會(huì)降低模型的表達(dá)能力,因此作者提出要求在圖像流型上保持局部 Lipschitz 條件即可,而這一點(diǎn)通過對(duì)抗訓(xùn)練可以很容易地滿足。 經(jīng)過上面的分析我們發(fā)現(xiàn),對(duì)抗訓(xùn)練和 GAN 的結(jié)合是一個(gè)互幫互助的過程。在這個(gè)框架內(nèi)作者對(duì)生成器和鑒別器進(jìn)行端到端的訓(xùn)練:生成器向鑒別器提供假圖像; 同時(shí),從訓(xùn)練集采樣的真實(shí)圖像在發(fā)送到鑒別器之前由 PGD 攻擊算法預(yù)處理。其網(wǎng)絡(luò)架構(gòu)如圖 4 所示。 ▲ Figure 4. (LLV) 比較 Rob-GAN 的網(wǎng)絡(luò)架構(gòu) 實(shí)驗(yàn)在具體實(shí)驗(yàn)時(shí)鑒別器網(wǎng)路采用的是 AC-GAN 中的模型架構(gòu),只不過在 AC-GAN 中無論是生成器還是鑒別器都希望能最大化分類損失 LC,但這樣會(huì)導(dǎo)致即使生成器生成出特別差的樣本,損失函數(shù)還是希望其能正確分類。 因此作者將 LC 損失進(jìn)行了修改,也就是鑒別器希望盡可能正確分類真實(shí)樣本即最大化損失 LS+LC1,生成器希望能盡可能正確分類生成樣本即最大化 LC2-LS。 下面是在對(duì)抗訓(xùn)練采用 GAN 數(shù)據(jù)后的性能差距,可以明顯看出相較之前差距明顯縮小。 另外,作者對(duì) Rob-GAN 進(jìn)行了微調(diào)使鑒別器單獨(dú)執(zhí)行多分類問題以便能更好的比較 Rob-GAN 的效果。下面是 Rob-GAN 在 CIFAR10 以及 ImageNet 上不同擾動(dòng)情況下模型訓(xùn)練的準(zhǔn)確率,其中 FT 是指加入微調(diào)策略。 總結(jié)這篇論文作者將生成式對(duì)抗網(wǎng)絡(luò) (GAN) 以及對(duì)抗訓(xùn)練模型 (Adversarial training) 結(jié)合在一起形成一個(gè)全新的框架 Rob-GAN。從理論以及實(shí)驗(yàn)證明出 Rob-GAN 不僅能加速 GAN 收斂速度而且還有助于縮減對(duì)抗訓(xùn)練的性能差距,另外作者還重新定義了 AC-GAN 的損失函數(shù)。總之,我認(rèn)為這篇論文對(duì)于穩(wěn)定GAN訓(xùn)練具有重大意義,并且論文理論的嚴(yán)謹(jǐn)性也非常值得借鑒。 參考文獻(xiàn)[1]. M. Arjovsky, S. Chintala, and L. Bottou. Wasserstein gan.arXiv preprint arXiv:1701.07875, 2017. 2, 4 [2]. T. Miyato, T. Kataoka, M. Koyama, and Y. Yoshida. Spectral normalization for generative adversarial networks. In International Conference on Learning Representations, 2018 [3]. A. Madry, A. Makelov, L. Schmidt, D. Tsipras, and A. Vladu. Towards deep learning models resistant to adversarial attacks. arXiv preprint arXiv:1706.06083, 2017. 1, 2, 3, 6, 7 |
|