1 生成器

1.1 与传统模型的不同

传统模型通常只是读入一组特征,给出一个输出。然而生成器除了特征x外,还需要输入一个z(其中z需要服从某一简单分布)。z的分布必须可以用函数进行表示,因为只有知道了分布规律,才可以对z进行采样。
生成器输出的y的分布规律是未知的(复杂分布)。
image.png

1.2 为什么需要不确定的z?

  1. 设想一个吃豆人游戏,我们现在去训练一个网络,使其根据先前的几帧预测下面的一帧。<br />![image.png](https://cdn.nlark.com/yuque/0/2021/png/1295225/1629176488385-8abf7ad0-556d-46b2-a182-bcfbeebf2485.png#clientId=ubc6da337-a0e3-4&from=paste&height=635&id=ufd8d77be&margin=%5Bobject%20Object%5D&name=image.png&originHeight=635&originWidth=1031&originalType=binary&ratio=1&size=241834&status=done&style=none&taskId=ub3215f4a-565d-44b8-aa67-f0c32c48798&width=1031)<br /> 然而,如果使用普通网络,小精灵在岔路口通常会分裂成两个,因为通过训练资料模型会学习到无论往什么方向走都是可行的。又因为模型的输出是不具有随机性的,因此就造成了这种后果。

加入了不确定的变量z后,z可以代表在某一时刻所做出的选择,这样一来就可以解决这个问题:
image.png

特别地,当我们要求模型具有一定的创造力时(一个同样的输入有不同输出),引入服从简单分布的变量z是很有必要的。
image.png

2 生成对抗网络

2.1 结构

最终需要投入使用的模型实际上就是生成器。通过输入服从简单分布的向量z,我们可以得到最终服从复杂分布的更高维的向量,从而建立从简单分布到复杂分布的映射关系。
image.png
接下来还有判别器的概念,判别器用于给生成器输出的结果打分,越贴近真实结果,打分就越高,vice versa
image.png

2.2 思想

可以拿枯叶蝶进化的例子做一个类比。蝴蝶为了躲避鸟的捕杀,一步一步向枯叶的样子进化。相应的,鸟为了将蝴蝶与叶子区分开来,也进行着一步一步的进化。比如,一开始鸟只会通过颜色进行判别,后来学会观察是否有叶脉进行判别,等等。这个进化的过程是一个“对抗”的过程。实际上,蝴蝶就类似于生成器,而鸟就类似于判别器。
image.png
在GAN中,生成器在不断生成更好的图片企图骗过判别器,相应地,判别器也在不断提高自己的判别技术来将生成的图像与真实的图像区别开。
image.png

2.3 算法

首先初始化生成器与判别器。
在每一次迭代中:

  1. 固定生成器不变,随后随机采样一组简单分布的z,使其生成一组数据。此时从数据集也采样一组数据。训练判别器,使其能够生成的数据与真实数据。
  2. 固定判别器不变,训练生成器使其能够骗过判别器。

    如图所示:
    image.png

    3 GAN理论

    3.1 生成器训练目标

    事实上,我们就是为了让模型输出的复杂分布与真实数据的分布尽可能地接近。
    以一维特征作为例子,我们需要将一个服从简单分布(假设正态)的一维向量,丢进network中,然后映射成为复杂分布,而这个复杂分布越贴近真实的分布就越好。
    所以,优化生成器中参数的目标如下式所示。其中,divergence表示生成器产生的分布与真实分布之间的差异。
    Generative Adversarial Network - 图9
    image.png
    一个很大的问题是,我们并不知道这个divergence应该怎么求。GAN的一个很大的好处是只要我们会从数据集中进行sample,就可以计算divergence,而不需要知道数据分布的公式到底长什么样子。

    3.2 判别器训练目标

    判别器训练的目标实际上就是让生成器输出sample出来的结果打分尽可能低,真实数据尽可能高。
    训练目标:
    Generative Adversarial Network - 图11
    D优化的目标函数(越大越好,与loss function相反)为:
    Generative Adversarial Network - 图12
    image.png
    此外还有一点,那就是Generative Adversarial Network - 图14这个值(也就是真实数据与生成器生成数据差异的最大值)是与真是数据的JS Divergence息息相关的(原论文有证明)。
    所以,要计算真实数据与生成数据的div散度,我们可以训练一下判别器,看看优化目标函数的最大值可以到多大就ok了。
    从直观上理解:
    image.png

    3.3 进一步推导

    刚才提到,Generative Adversarial Network - 图16与divergence息息相关,也就是:
    Generative Adversarial Network - 图17
    那么,由Generative Adversarial Network - 图18就可以得到:
    Generative Adversarial Network - 图19
    Generative Adversarial Network - 图20
    所以,先训练判别器,再训练生成器。
    image.png