干貨|全面分析GAN,以及如何用TF實(shí)現GAN?
生成式模型能夠作為一種技術(shù)手段輔助強化學(xué)習,能夠有效表征強化學(xué)習模型中的state狀態(tài)。
前言
本文會(huì )從頭了解生成對抗式網(wǎng)絡(luò )的一些內容,從生成式模型開(kāi)始說(shuō)起,到GAN的基本原理,InfoGAN,AC-GAN的基本科普,如果有任何有錯誤的地方,請隨時(shí)噴,我 剛開(kāi)始研究GAN這塊的內容,希望和大家一起學(xué)習這塊內容。
生成式模型
何為生成式模型?在很多machine learning的教程或者公開(kāi)課上,通常會(huì )把machine learning的算法分為兩類(lèi): 生成式模型、判別式模型;其區別在于: 對于輸入x,類(lèi)別標簽y,在生成式模型中估計其聯(lián)合概率分布,而判別式模型估計其屬于某類(lèi)的條件概率分布。 常見(jiàn)的判別式模型包括:LogisticRegression, SVM, Neural Network等等,生成式模型包括:Naive Bayes, GMM, Bayesian Network, MRF 等等
研究生成式模型的意義
生成式模型的特性主要包括以下幾個(gè)方面:
在應用數學(xué)和工程方面,生成式模型能夠有效地表征高維數據分布;
生成式模型能夠作為一種技術(shù)手段輔助強化學(xué)習,能夠有效表征強化學(xué)習模型中的state狀態(tài)(這里不擴展,后面會(huì )跟RL的學(xué)習筆記);
對semi-supervised learning也有比較好的效果,能夠在miss data下訓練模型,并在miss data下給出相應地輸出;
在對于一個(gè)輸入伴隨多個(gè)輸出的場(chǎng)景下,生成式模型也能夠有效工作,而傳統的機器學(xué)習方法通過(guò)最小化模型輸出和期望輸出的某個(gè)object function的值 無(wú)法訓練單輸入多輸出的模型,而生成式模型,尤其是GAN能夠hold住這種場(chǎng)景,一個(gè)典型的應用是通過(guò)場(chǎng)景預測video的下一幀;
生成式模型一些典型的應用:
圖像的超分辨率
iGAN:Generative Visual Manipulation on the Natural Image Manifold
圖像轉換
生成式模型族譜
上圖涵蓋了基本的生成式模型的方法,主要按是否需要定義概率密度函數分為:
Explicit density models
explicit density models 又分為tractable explicit models和逼近的explicit model,怎么理解呢,tractable explicit model通??梢灾苯油ㄟ^(guò)數學(xué)方法來(lái)建模求解,而基于逼近的explicit model通常無(wú)法直接對數據分布進(jìn)行建模,可以利用數學(xué)里的一些近似方法來(lái)做數據建模, 通?;诒平膃xplicit model分為確定性(變分方法:如VAE的lower bound)和隨機性的方法(馬爾科夫鏈蒙特卡洛方法)。
VAE lower bound:
馬爾科夫鏈蒙特卡洛方法(MCMC),一種經(jīng)典的基于馬爾科夫鏈的抽樣方法,通過(guò)多次來(lái)擬合分布。比較好的教程:A Beginner’s Guide to Monte Carlo Markov Chain MCMC Analysis, An Introduction to MCMC for Machine Learning.
Implicit density models
無(wú)需定義明確的概率密度函數,代表方法包括馬爾科夫鏈、生成對抗式網(wǎng)絡(luò )(GAN),該系列方法無(wú)需定義數據分布的描述函數。
生成對抗式網(wǎng)絡(luò )與其他生成式網(wǎng)絡(luò )對比
生成對抗式網(wǎng)絡(luò )(GAN)能夠有效地解決很多生成式方法的缺點(diǎn),主要包括:
并行產(chǎn)生samples;
生成式函數的限制少,如無(wú)需合適馬爾科夫采樣的數據分布(Boltzmann machines),生成式函數無(wú)需可逆、latent code需與sample同維度(nonlinear ICA);
無(wú)需馬爾科夫鏈的方法(Boltzmann machines, GSNs);
相對于VAE的方法,無(wú)需variational bound;
GAN比其他方法一般來(lái)說(shuō)性能更好。
GAN工作原理
GAN主要由兩部分構成:generator和discriminator,generator主要是從訓練數據中產(chǎn)生相同分布的samples,而discriminator 則是判斷輸入是真實(shí)數據還是generator生成的數據,discriminator采用傳統的監督學(xué)習的方法。這里我們可以這樣類(lèi)比,generator 是一個(gè)偽造假幣的專(zhuān)業(yè)人士,discriminator是警察,generator的目的是制造出盡可能以假亂真的假鈔,而discriminator是為了能 鑒別是否為假鈔,最終整個(gè)gan會(huì )達到所謂的納什均衡,Goodfellow在他的paper中有嚴格的數學(xué)證明,當$p_G$==$p_{data}$時(shí)達到 全局最優(yōu):
另一個(gè)比較明顯看得懂的圖如下:
圖中黑色點(diǎn)線(xiàn)為真實(shí)數據分布$p_{data}$,綠色線(xiàn)為generator生成的數據分布$p_{G}$,而Discriminator就是藍色點(diǎn)線(xiàn),其目的是為了將$p_{data}$和$p_{G}$ 區分,(a)中是初始狀態(tài),然后會(huì )更新Discriminator中的參數,若干次step之后,Discriminator有了較大的判斷力即到了(b)的狀態(tài),之后會(huì )更新G的模型使其生成的數據分布(綠色線(xiàn))更加趨近與真實(shí)數據分布, 若干次G和D的模型參數更新后,理論上最終會(huì )達到(d)的狀態(tài)即G能夠產(chǎn)生和真實(shí)數據完全一致的分布(證明見(jiàn)上一張圖),如從隨機數據分布生成人臉像。
如何訓練GAN
因為GAN結構的不同,和常規訓練一個(gè)dl model方法不同, 這里采用simultaneous SGD,每一個(gè)step中,會(huì )有兩個(gè)兩個(gè)梯度優(yōu)化的 過(guò)程,一個(gè)是更新discriminator的參數來(lái)最小化$J_{(D)}$,一個(gè)是更新generator的參數來(lái)最小$J_{(G)}$,通常會(huì )選用Adam來(lái)作為最優(yōu)化的優(yōu)化器, 也有人建議可以不等次數地更新generator和discriminator(有相關(guān)工作提出,1:1的在實(shí)際中更有效:Adam: A Method for Stochastic Optimization) 如何訓練GAN,在Goodfellow的GAN的tutorial還有一些代碼中有更多的描述包括不同的cost function, 這里我就不詳細展開(kāi)了。
DCGAN
GAN出來(lái)后很多相關(guān)的應用和方法都是基于DCGAN的結構,DCGAN即”Deep Convolution GAN”,通常會(huì )有一些約定俗成的規則:
在Discriminator和generator中大部分層都使用batch normalization,而在最后一層時(shí)通常不會(huì )使用batch normalizaiton,目的 是為了保證模型能夠學(xué)習到數據的正確的均值和方差;
因為會(huì )從random的分布生成圖像,所以一般做需要增大圖像的空間維度時(shí)如77->1414, 一般會(huì )使用strdie為2的deconv(transposed convolution);
通常在DCGAN中會(huì )使用Adam優(yōu)化算法而不是SGD。
各種GAN
這里有個(gè)大神把各種gan的paper都做了一個(gè)統計AdversarialNetsPapers
這里大家有更多的興趣可以直接去看對應的paper,我接下來(lái)會(huì )盡我所能描述下infogan和AC-GAN這兩塊的內容
InfoGAN
InfoGAN是一種能夠學(xué)習disentangled representation的GAN(https://arxiv.org/pdf/1606.03657v1.pdf),何為disentangled representation?比如人臉數據集中有各種不同的屬性特點(diǎn),如臉部表情、是否帶眼睛、頭發(fā)的風(fēng)格眼珠的顏色等等,這些很明顯的相關(guān)表示, InfoGAN能夠在完全無(wú)監督信息(是否帶眼睛等等)下能夠學(xué)習出這些disentangled representation,而相對于傳統的GAN,只需修改loss來(lái)最大化GAN的input的noise(部分fixed的子集)和最終輸出之間的互信息。
原理
為了達到上面提到的效果,InfoGAN必須在input的noise來(lái)做一些文章,將noise vector劃分為兩部分:
z: 和原始的GAN input作用一致;
c: latent code,能夠在之后表示數據分布中的disentangled representation
那么如何從latent code中學(xué)到相應的disentangled representation呢? 在原始的GAN中,忽略了c這部分的影響,即GAN產(chǎn)生的數據分布滿(mǎn)足$P_{G}(x|C)=P(x)$,為了保證能夠利用c這部分信息, 作者提出這樣一個(gè)假設:c與generator的輸出相關(guān)程度應該很大,而在信息論中,兩個(gè)數據分布的相關(guān)程度即互信息, 即generator的輸出和input的c的$I(c;G(z,c))$應該會(huì )大。 所以,InfoGAN就變成如下的優(yōu)化問(wèn)題:
因為互信息的計算需要后驗概率的分布(下圖紅線(xiàn)部分),在實(shí)際中很難直接使用,因此,在實(shí)際訓練中一般不會(huì )直接最大化$I(c;G(z,c))$
這里作者采用和VAE類(lèi)似的方法,增加一個(gè)輔助的數據分布為后驗概率的low bound: 所以,這里互信息的計算如下:
這里相關(guān)的證明就不深入了,有興趣的可以去看看paper。
實(shí)驗
我寫(xiě)的一版基于TensorFlow的Info-GAN實(shí)現:Info-GANhttps://github.com/burness/tensorflow-101/tree/master/GAN/Info-GAN random的label信息,和對應生成的圖像:
不同random變量控制產(chǎn)生同一class下的不同輸出:
AC-GAN
AC-GAN即auxiliary classifier GAN,對應的paper:https://arxiv.org/abs/1610.09585, 如前面的示意圖中所示,AC-GAN的Discriminator中會(huì )輸出相應的class label的概率,然后更改loss fuction,增加class預測正確的概率, ac-gan是一個(gè)tensorflow相關(guān)的實(shí)現,基于作者自己開(kāi)發(fā)的sugartensor,感覺(jué)和paper里面在loss函數的定義上差異,看源碼的時(shí)候注意下,我這里有參考寫(xiě)了一個(gè)基于原生tensorflow的版本AC-GAN.
實(shí)驗
各位有興趣的可以拿代碼在其他的數據集上也跑一跑,AC-GAN能夠有效利用class label的信息,不僅可以在G時(shí)指定需要生成的image的label,同事該class label也能在Discriminator用來(lái)擴展loss函數,增加整個(gè)對抗網(wǎng)絡(luò )的性能。 random的label信息,和對應生成的圖像:
不同random變量控制產(chǎn)生同一class下的不同輸出:
Summary
照例總結一下,本文中,我基本介紹了下生成式模型方法的各個(gè)族系派別,到GAN的基本內容,到InfoGAN、AC-GAN,大部分的工作都來(lái)自于閱讀相關(guān)的paper,自己相關(guān)的工作就是 tensorflow下參考sugartensor的內容重現了InfoGAN、AC-GAN的相關(guān)內容。 當然,本人菜鳥(niǎo)一枚,難免有很多理解不到位的地方,寫(xiě)出來(lái)更多的是作為分享,讓更多人了解GAN這塊的內容,如果任何錯誤或不合適的地方,盡情在評論中指出,我們一起討論一起學(xué)習。另外我的所有相關(guān)的代碼都在github上:GAN,相信讀一下無(wú)論是對TensorFlow的理解還是GAN的理解都會(huì ) 有一些幫助,簡(jiǎn)單地參考mnist.py修改下可以很快的應用到你的數據集上,如果有小伙伴在其他數據集上做出有意思的實(shí)驗效果的,歡迎分享。
本文轉自全球人工智能,作者:UCloud應用創(chuàng )新部深度學(xué)習研發(fā)工程師Burness Duan
最后,記得關(guān)注微信公眾號:鎂客網(wǎng)(im2maker),更多干貨在等你!
硬科技產(chǎn)業(yè)媒體
關(guān)注技術(shù)驅動(dòng)創(chuàng )新
