《Adversarial Neural Machine Translation》阅读笔记

原文链接:arxiv.org/abs/1704.0693

文章简介:本篇文章主要提出如何将对抗网络机制(GAN)应用于神经网络翻译模型。过去的神经网络翻译模型中,以最大化人为翻译结果的似然作为目标来更新模型的参数,更准确地说,模型的loss是以ground truth为基准的交叉熵。在本文的模型中,目标则是最小化机器翻译的结果和人为翻译结果的差异。和过往的GAN的框架一样,该模型主要包含G网络和D网络。G网络是生成网络,对翻译任务而言,输入是源序列x,该子网络生成翻译结果y(目标序列)。D网络是判别网络,输入是源序列x和翻译结果y',该子网络判断y'来自于G网络还是人为翻译。模型的最终目标是让G网络生成足够欺骗D网络的翻译结果。在具体实现时,通过增强学习中的决策梯度来更新G网络参数(也包含以往以交叉熵为目标函数做梯度更新的方式,后文详细解释)。模型在English\rightarrowFrench 和 German\rightarrowEnglish 两个任务中都取得了比几个强baseline要好的结果。

模型:

  • G网络

G网络仍然采用的是基本的基于RNN的encoder-decoder框架。给定源序列x
作为输入,G网络生成目标语言翻译y'
。不妨将该网络学到的概率映射表示为G(y|x),按照该分布抽样出目标序列y'\sim G(y|x)

具体的,给定源序列x
和前面生成的部分序列y_{<t},那么当前时刻生成词y_t的概率是:


其中,r_t
是decoder在t时刻隐藏输出。这里g是循环神经元,可以是LSTM或者GRU。c_t
表示的是t时刻的context,通过attention机制计算得到,计算方式如下:

其中,a是一个前馈神经网络。h_i是RNN-encoder的隐层输出,其计算公式如下:

这一部分的内容和过往的模型没有什么区别,不了解的话建议去看一下较早的一篇论文Neural Machine Translation by Jointly Learning to Align and Translate


  • D网络

D网络

D网络接受的输入是source-target sentence pair (x,y)。D网络判断y是针对x的机器翻译还是人为翻译。句子对对被处理成类似图像的形式输入。具体的,将源序列x中的第i个词x_i和目标序列y中的第j个词y_j对应的词向量拼接起来:

由此得到一个三维的张量表达。再通过3\times 3window的卷积层和2\times 2window的maxPooling得到更抽象的特征表达。如此重复几次卷积和pooling操作,再将得到的特征表达送入多层感知机,由最后一层的sigmoid激活函数输出2分类的概率分布。

注:将源序列和目标序列中的词两两配对作为输入具有一定的合理性,能够让判别模型捕捉到源序列和目标序列词之间的关联信息。但是这种D网络的设计是否足够有效以及是否有改进的空间值得思考。

  • 梯度计算

D网络直接通过有监督的交叉熵计算梯度来更新参数。G网络则通过增强学习中的决策梯度来更新参数。具体的,G网络的目标是最小化如下的Loss函数:

求导后:

透过抽样y'
来近似上面的梯度并更新G网络参数:

在上面的公式中,可以将-\log(1-D(x,y'))看做reward。不同于seqGAN中在每一个时间步都计算一次reward(即固定前面和当前时间步已经生成的词,按概率继续多次抽样出完整的句子再计算一个平均reward),本文的模型仅仅计算按概率抽样生成一个完整句子后的reward。文中提到这么做会有方差偏大的问题,但是如果语料足够大,受影响的程度会被减小。

注:这里reward的设计还是显得比较简单。理想情况下,reward对于好的样例应该为正,对差的样例为负。个人理解得也不深,可以参考链接zhuanlan.zhihu.com/p/26

  • 实验和结果

在实验中的几个点还需要说明一下。
  1. 和其他的GAN模型一样。本文的GAN仍然需要warm start一下,即需要通过和以前一样用交叉熵作为目标函数进行预训练,将模型参数训练至一个相对合理的范围。
  2. 文中提到了一个很关键的训练trick。即在每一个训练batch中,按50%的概率抽出部分数据用决策梯度训练,另外的用MLE训练。文中说,这么做是为了让MLE充当一个正则的作用,以控制增强学习中抽样引起的高方差带来的消极影响。

在两个翻译任务上的结果截图如下,这里不再做分析:


应该说都是取得了state-of-art的结果。文中另外探索了训练速率对G网络和D网络的影响


从结果来看,训练速率对G网络的影响会比较大,大的训练速率会使得结果出现比较大的上下浮动,小速率又会使得模型不容易找到最优点。而D网络则显得比较robust。

总结:

这应该是第一篇将GAN用于机器翻译任务中的文章。文中D网络的设计和训练trick都很有创意。但是个人感觉,GAN实际带来的好处体现得还不是很明显,一些细节还存在可以思考改进的空间。

编辑于 2017-06-02

文章被以下专栏收录

    「明光村职业技术学院」一群努力磕盐小伙伴的日常Paper阅读笔记分享。欢迎关注机器学习、深度学习、自然语言处理等方向的同好们共同探讨,共同学习!