GAN基本结构

GAN全称对抗生成网络,顾名思义是生成模型的一种,而他的训练则是处于一种对抗博弈状态中的。GAN的主要结构包括一个生成器G(Generator)和一个判别器D(Discriminator)。以下通过一个简单的例子展示了GAN的结构。

GAN生成对抗网络 - 图1

假设我们现在拥有大量的手写数字的数据集,我们希望通过GAN生成一些能够以假乱真的手写字图片。主要由如下两个部分组成:

  1. 定义一个模型来作为生成器(图三中蓝色部分),能够输入一个向量,输出手写数字大小的像素图像。
  2. 定义一个分类器来作为判别器(图三中红色部分)用来判别图片是真的还是假的(或者说是来自数据集中的还是生成器中生成的),输入为手写图片,输出为判别图片的标签。

这里生成器输入的向量我们将其视为携带输出的某些信息,在不加限制的情况下它们没有任何语义含义,可以为随机噪声。
GAN的训练过程如下:

  1. 初始化判别器GAN生成对抗网络 - 图2和生成器GAN生成对抗网络 - 图3的参数;
  2. 从真实数据集中采样m个真实样本GAN生成对抗网络 - 图4,从给定噪声分布中采样m个噪声样本GAN生成对抗网络 - 图5,将它们输入生成器GAN生成对抗网络 - 图6以获取m个生成样本GAN生成对抗网络 - 图7。固定生成器GAN生成对抗网络 - 图8,训练判别器GAN生成对抗网络 - 图9以区分真实样本与生成的假样本。(当然一开始假样本都是随机噪声)
  3. 循环k次迭代更新GAN生成对抗网络 - 图10后,使用较小的学习率来更新一次GAN生成对抗网络 - 图11,训练生成器使其尽可能减小生成样本与真是样本之间的差距。(由于GAN生成对抗网络 - 图12的训练依赖于GAN生成对抗网络 - 图13GAN生成对抗网络 - 图14的准确度非常重要,因此一般GAN生成对抗网络 - 图15的训练密度比GAN生成对抗网络 - 图16高,即k>1)
  4. 多次迭代更新后,理想状况是:GAN生成对抗网络 - 图17生成的假样本能够以假乱真,GAN生成对抗网络 - 图18无法分辨出样本的真假。