生成对抗模型是个非常有趣的神经网络模型,最近因为比赛需要用到对抗模型,所以就稍微阅读了一下Goodfellow大神的论文《Generative Adversarial Nets》,顺便把一些理解以及笔记写下来供以后查阅。

模型简介

长久以来,基于深度学习的判别模型取得了巨大的成功,例如CNN/MLP等技术,但是基于深度学习的生成模型却一直不温不火,究其原因,作者认为主要难点在于难以解决极大似然估计和相关策略中概率近似的复杂计算问题。为此,作者提出了GAN的生成对抗网络框架以回避这些问题。
GAN模型主要由两个网络组成,分别是Generator生成器生成对抗模型——Generative Adversarial Nets - 图1以及Discriminator判别器生成对抗模型——Generative Adversarial Nets - 图2。生成器的任务是模拟数据的分布,生成仿真数据,同时尽可能地欺骗判别器。而判别器的任务是尽可能分辨真实数据与生成器生成的仿真数据,这种生成——判别之间的博弈被称为对抗。

模型定义

假定数据生成对抗模型——Generative Adversarial Nets - 图3服从分布生成对抗模型——Generative Adversarial Nets - 图4,为了模拟数据的真实分布,我们需要利用生成器生成对抗模型——Generative Adversarial Nets - 图5,并预先定义一个已知的噪声分布生成对抗模型——Generative Adversarial Nets - 图6,实际应用中,噪声一般采用高斯噪声或者均匀分布的噪声。生成器生成对抗模型——Generative Adversarial Nets - 图7输入为噪声生成对抗模型——Generative Adversarial Nets - 图8,输出为模拟的数据生成对抗模型——Generative Adversarial Nets - 图9生成对抗模型——Generative Adversarial Nets - 图10生成对抗模型——Generative Adversarial Nets - 图11的参数。判别器生成对抗模型——Generative Adversarial Nets - 图12就是一个二分类器,其作用是判别输入数据是否为真实数据的概率,即生成对抗模型——Generative Adversarial Nets - 图13,其中生成对抗模型——Generative Adversarial Nets - 图14为判别器的参数。生成器与判别器的最简单形式为多层感知机。
前面说过,该框架下,判别器生成对抗模型——Generative Adversarial Nets - 图15需要尽可能让生成对抗模型——Generative Adversarial Nets - 图16接近于1,生成对抗模型——Generative Adversarial Nets - 图17接近于0。而生成器的任务是尽可能让生成对抗模型——Generative Adversarial Nets - 图18接近于1,用min-max博弈来写的话就是:
生成对抗模型——Generative Adversarial Nets - 图19
至于为什么要用log代替原始概率,作者并没有给出解释,可能是受到交叉熵函数影响吧,实际上损失函数可以采用别的形式,不必拘泥于log。

训练步骤

在论文中作者采用的训练方法是,连续k步对判别器生成对抗模型——Generative Adversarial Nets - 图20进行优化,随后1步对生成器生成对抗模型——Generative Adversarial Nets - 图21进行优化。同时,在实际训练中,由于刚开始的生成器生成对抗模型——Generative Adversarial Nets - 图22效果很差,对于判别器来说非常容易对生成器的数据进行判别,所以可以将最小化生成对抗模型——Generative Adversarial Nets - 图23)改为最大化生成对抗模型——Generative Adversarial Nets - 图24,最终收敛效果一致,但是这种改动使得训练前期的loss更为有效(即能更快地优化生成器参数),原因在于,假设生成对抗模型——Generative Adversarial Nets - 图25的参数为生成对抗模型——Generative Adversarial Nets - 图26,那么梯度值可以写为:
生成对抗模型——Generative Adversarial Nets - 图27
由于前期判别器很强势,所以可以轻松识别生成器数据,故而生成对抗模型——Generative Adversarial Nets - 图28,但是如果采用生成对抗模型——Generative Adversarial Nets - 图29,则梯度可以写为:
生成对抗模型——Generative Adversarial Nets - 图30
此时梯度值很大,可以有效更新生成器参数加快模型收敛,具体的训练步骤如下:

image.png
需要注意,在更新判别器参数时需要增加损失函数值,故而模型参数需要加上梯度值,在更新生成器时则相反。当然在实际写代码时可以更灵活一些,改变损失函数形式,分别最小化两个不一样的损失函数即可。

理论分析

简单地复读一下作者在论文中给出的理论分析,当生成器生成对抗模型——Generative Adversarial Nets - 图32不变时,最佳判别器为:
生成对抗模型——Generative Adversarial Nets - 图33
由于:
生成对抗模型——Generative Adversarial Nets - 图34
其中生成对抗模型——Generative Adversarial Nets - 图35是生成器产生样本生成对抗模型——Generative Adversarial Nets - 图36对应的概率分布,对上述式子求导并等0,可以得到:
生成对抗模型——Generative Adversarial Nets - 图37
在理想情况下,最终会收敛到生成对抗模型——Generative Adversarial Nets - 图38,即生成器完美拟合原始数据分布。

写在最后

在那么多模型中,生成对抗模型的确是一个特别有意思的模型,可以用来做很多有趣的事情。以前实习的时候听到别人对GAN的评价就是搞的东西很有意思,但是没有软用(商业应用前景)。如果这次比赛拿名次的话会介绍一下我们队在比赛中用到的对抗模型,以后有机会也可能用GAN来做一些有意思的项目。