WGAN-GP方法介绍

原文标题:Improved Training of Wasserstein GANs

原文链接[1704.00028] Improved Training of Wasserstein GANs

背景介绍

训练不稳定是GAN常见的一个问题。虽然WGAN在稳定训练方面有了比较好的进步,但是有时也只能生成较差的样本,并且有时候也比较难收敛。原因在于:WGAN采用了权重修剪(weight clipping)策略来强行满足critic上的Lipschitz约束,这将导致训练过程产生一些不希望的行为。本文提出了另一种截断修剪的策略-gradient penalty,即惩罚critic相对于其输入(由随机噪声z生成的图片,即fake image)的梯度的norm。就是这么一个简单的改进,能使WGAN的训练变得更加稳定,并且取得更高质量的生成效果。

注意:GAN之前的D网络都叫discriminator,但是由于这里不是做分类任务,WGAN作者觉得叫discriminator不太合适,于是将其叫为critic。

方法介绍

介绍WGAN-GP方法前,先简单介绍一下WGAN,WGAN的损失函数如下:

公式1

这里需要注意的是,WGAN的提出是作者分析了一堆统计度量(KL散度,JS散度,TV距离,W距离等)后,得出Wasserstein距离(下简称W距离)最适合GAN的训练。按理说WGAN的损失函数就是一个分布到另一个分布的W距离,如公式2所示:

公式2

但是大家可以看到,公式2中有个下确界符号inf,让人看着有点懵。不过没关系,Kantorovich-Rubinstein duality理论(该理论太复杂,这里不介绍)告诉我们:当critic满足Lipschitz连续条件时,公式2可以转化为公式1的形式。直观上,公式2跟神经网络半毛钱关系没有,给出这个公式我们也不会优化,但是公式1一看就是一个标准的神经网络的损失函数(把x和z当做输入,G和W当做网络的两部分)。公式2转化为公式1的形式后,就可以用神经网络中常用的梯度下降法去优化了。

注意这里有个名词:“Lipschitz连续”,大家不要被这个牛逼的名字吓到,其概念其实很简单,意思就是定义域内每点的梯度恒定不超过某个常数(常数是多少无所谓,不是无穷就行)。那么怎么来保证critic的Lipschiz连续呢?作者用的方法极其简单,就是weight clip策略。weight clip策略的意思是:限制神经网络 f_{w} 的所有参数w不超过某个范围[-c, c](比如[-0.01, 0.01]),即大于c的置为c,小于-c的置为-c。为什么这样做能保证Lipschiz连续(定义域内每点的梯度不超过某个常数)呢?因为critic相对于其输入的导数是个含w的表达式,w不超过某个范围,那critic相对于其输入的梯度一定也不会超过某个范围,Lipschiz连续条件得以满足。

这么粗暴的做法WGAN作者也是觉得不妥的,但是暂时没有想到更好的办法,只能用这个简单的方法了。WGAN-GP就是从这点入手做文章。

其实,WGAN-GP方法的作者也是普通人,一开始想到的也是很普通的方法,比如把weight clipping这么粗暴的方法改为L2 norm clip,做权重的归一化等。然并卵,这些方法的效果跟带weight clipping的WGAN效果没啥区别。作者也尝试了batch normalization的方法,但是发现当critic太深时,WGAN难以收敛。于是,才有了WGAN-GP方法。WGAN-GP的目标函数如下所示:

公式3

可以看到,WGAN-GP相对于WGAN的改进很小,除了增加了一个正则项,其他部分都和WGAN一样。 这个正则项就是WGAN-GP中GP(gradient penalty),即梯度约束。这个约束的意思是:critic相对于原始输入的梯度的L2范数要约束在1附近(双边约束)。为什么这个约束是合理的,这里作者给了一个命题,并且在文章补充材料中给出了证明,这个证明大家有兴趣可以自己去看,这里只想简单介绍一下这个命题。这个命题说的是在最优的优化路径上(把生成分布推向真实分布的“道路”上),critic函数对其输入的梯度值恒定为1。有了这个知识后,我们可以像搞传统机器学习一样,将这个知识加入到目标函数中,以学习到更好的模型。

这里需要说明一下,WGAN-GP作者加的这个约束能保证critic也是一个Lipschiz连续函数。因为critic对任意输入x的梯度都是一个含参数w的表达式,而这个梯度的L2 norm大小约束在1附近,那w也不超过某个常数。因而从保证Lipschiz连续的条件上,GP的作用跟weight clip是一样的。

WGAN-GP具体算法步骤如下:

可以看出跟WGAN不同的主要有几处:1)用gradient penalty取代weight clipping;2)在生成图像上增加高斯噪声;3)优化器用Adam取代RMSProp。

这里需要注意的是,这个GP的引入,跟一般GAN、WGAN中通常需要加的Batch Normalization会起冲突。因为这个GP要求critic的一个输入对应一个输出,但是BN会将一个批次中的样本进行归一化,BN是一批输入对应一批输出,因而用BN后无法正确求出critic对于每个输入样本的梯度。

实验介绍

作者做了很多实验,主要就是为了说明WGAN中GP用于保证Lipschiz连续的方式要比weight clip好,因而能稳定训练,并生成质量比较好的图像。具体实验结论如下:

1)WGAN中的weight clip策略,会导致学到的绝大部分weight趋近于两个极端(-c和c),但是WGAN-GP学到的梯度是均匀分布在某个区间的。见图1。

图1

2)当critic选择的是比较深的网络时,WGAN中的c值不管怎么选取,都容易出现梯度爆炸或者梯度消失问题。见图1。

图2

3)采用weight clip的策略训练出的critic无法捕获数据分布的高阶矩信息。比如图2中第二行第一列的图中,WGAN-GP生成数据的critic值基本都分布在8个高斯附近,这和输入的样本信息(8个高斯)是一致的。但是第一行第一列中,WGAN-GP生成数据的critic值就没有如此优良特性。

4)WGAN-GP比WGAN效果要好。作者在cifar10上做了对比实验,结果如图3所示。在同样实验设置下,WGAN-GP结果要明显比WGAN好,跟DCGAN差不多,但是训练要比DCGAN稳定。

图3

5)其他GAN当G或者D改变,或者加不同激活函数时,效果差别很大,有的会训练不好,有的会出现mode collapse。但是,WGAN-GP对于各种不同的结构效果都很好。实验设置和相应实验结果如图4所示。

图4

5)WGAN-GP的loss曲线是有意义的。WGAN文章中介绍到WGAN的loss是和其样本生成质量相关的,即loss越小,生成样本质量越好。WGAN-GP也保持了这个特性。不仅如此,WGAN-GP的loss还能反映出过拟合的情况。如图5所示。

图5

总结

本文提出了一种梯度惩罚策略,来取代WGAN中的weight clipping策略,从而使WGAN的训练变得更加稳定,生成的图像质量更好。个人认为WGAN-GP最好的性质在于不用太关注网络结构的设计,无论采用什么样的结构都能训练得比较好。

编辑于 2018-12-25