PaperWeekly
首发于PaperWeekly

对Attention机制的理解

多图预警。。。。。。。。。。。。。。。。。。。。。。。。。。。。

本文来自CS224n的Machine Translation, Seq2Seq and Attention一章的slides,强烈推荐

加上一些自己的理解,slides用图宏观上理解,想要深入,可以看官方的notes


Neural Machine Translation

机器翻译是seq2seq的一个主要使用领域

机器翻译早期的时候都是基于规则的系统,到了1990s-2010s年代,使用的是基于统计的机器翻译系统Statistical Machine Translation(SMT),而发展到2014年,随着神经机器翻译 Neural Machine Translation(NMT)的出现,机器翻译领域得到了爆炸性的发展

NMT是使用神经网络做机器翻译的方法,而这个神经网络的架构叫做sequence-to-sequence(aka seq2seq)

整个seq2seq架构如下图,其中包括两个RNNs结构,一个叫做Encoder RNN 编码器,一个叫做Decoder RNN 解码器,编码器负责把源序列编码成向量,解码器是一个语言模型,负责根据编码的信息生成目标序列:


seq2seq是一种端到端(end-to-end)的系统,所以训练seq2seq是把编码器解码器看作一个整体去训练,整个训练过程如下图,细的箭头代表前向传播的过程,粗的蓝色箭头代表反向传播的过程


Beam Search

在这个训练过程中有一个点,我们怎样在decode过程中,对于预测结果,我们取最大概率的词作为输出,这样的方法叫做greedy search decoding 贪婪搜索解码,即每一个时间步都取概率最大的那个值,但是这个方法有一个问题,就是没办法撤销决定,什么意思呢?

例如我们翻译真实句子是"he hit me with a pie",当我们的模型预测target sentence时,万一预测到第三个时间步就错了"he hit a",那么这是系统是没有机会纠正这个错误的,所以贪婪搜索解码太严格了,我们需要有一定的容错性,去在多个结果中选择最好的

那么自然想到在另外一种极端方向的做法就是Exhaustive search decoding,最全搜索解码,就是我们每一个时间步,都去计算V种可能,V是词库的大小,然后综合各个时间步就可以计算出,每一种可能的序列组合,然后看哪一个序列概率最大,这种方法很明显,计算太昂贵了

这种来说,比较好的是使用Beam search decoding,核心思想是,在每一个时间步,都去计算k个最可能的结果,k叫做beam size,实际中比较常使用5到10

Beam search不一定保证能找到最优的方案,但是比exhaustive search要有效率太多

如图所示,是当k=2时,Beam search的一个例子:

整个过程是这样的:

step1:

从<START>开始,计算下一个可能出现的字符的概率,取top2,he 和 I

step2:

对he计算下一个可能出现的字符的概率,取top2,hit 和 struck

同样的,对 I 计算下一个可能出现的字符的概率,取top2,was 和 got

这时产生了四种组合,保留top2,因为hit, was 概率最大,所以保留这两个,再去做同样的事

step3:

对hit计算,取top2,产生a 和 me; 对was 计算,取top2,产生hit 和 struck

保留a 和 me

step4:

对a计算,产生tart, pie;对me计算,产生with, on

保留pie, with

step5:

对pie计算,产生in, with;对with计算,产生a, one

保留a, one

step6:

对a计算,产生pie, with;对one计算,产生pie, tart

保留pie

step7:

从pie,往前backtrack,得到完整的句子

你可能会有疑惑,什么时候beam search停下来呢?

在greedy search中,我们通常有开始符和结束符,如“<START> he hit me with a pie <END>”

但在beam search中,不同的组合可能在不同的时间步产生<END>结束符,所以当某一个组合产生结束符时,我们就认为它结束了,先放到一边,继续探索其他的可能性

我们通常

  • 当达到T时间步,就停止,T是提前定义的,或者
  • 我们已经有了至少n种组合,就停止,n是提前定义的

最终我们有多个输出序列,怎样选分数最高的?

这样来选有一个问题,长度越长的序列,分数就越低,所以用长度来标准化一下:

NMT这么好,相比SMT也还是有一些缺点的:

  • 可解释性差,难以debug
  • 难控制,比如没法轻松的加进去一些特殊规则,还有一些安全问题

Evaluate MT

怎样衡量机器翻译任务的好坏,我们使用BLEU (Bilingual Evaluation Understudy)

BLEU把机器翻译的句子,和一个或者多个人工翻译的句子作比较,计算相似度,

这个相似度是基于n-gram,再加上对翻译太短的惩罚

BLEU是有用的但是不完美的,因为

  • 一个句子有很多种正确的翻译方式
  • 一个好的翻译句子或许会因为和人工翻译有很低的n-gram重合度,而得到一个低的BLEU分数

BLEU还没有研究过,只介绍一些概念。。。


重点来了。。。。。。。。。。。。。。。。。。Attention!

Attention in diagram

NMT研究引领了最近NLP深度学习的许多创新,2019年有一项对seq2seq的改进是至关重要的,他就是attention机制

首先回顾一下seq2seq的结构:

这个结构的问题在于,编码器需要把整个Source sentence的信息全部编码起来,这是seq2seq架构的瓶颈所在,attention机制就是解决这个瓶颈的一种方法

Attention机制的核心想法就是:在解码器的每一个时间步,都和编码器直接连接,然后只关注source sentence中的特定的一部分

先用图来看一下attention机制在做什么,不考虑数学公式:

Step1:

在解码阶段,第一个时间步开始,把hidden layer生成的hidden state,与编码器的第一个时间步的hidden state做点乘,产生一个数值,叫做attention scores

接下来对编码器每一个时间步都这样操作,每一个时间步都产生一个attention scores:


Step2:

对刚才产生的scores,通过softmax,转变成概率分布,再将这个概率分布,与编码器的所有hidden states,做一个加权求和,得到一个attention output,很容易看出来,这个输出中,哪一个hidden state的权重最大,就包含了越多它的信息(可以看出解码器的第一个时间步,对编码器的第一个时间步‘关注’最多)


Step3:

把这个attention output和解码器的第一个时间步的hidden state合并起来,用来计算 \hat y_1 ,​ \hat y_1 中概率最大的就是预测的值,这里举例是he

Step4:

按照前面所说的第一个时间步的做法,对解码器的每一个时间步都这样去做,得到最终的输出,概括来说,就是解码器的每一个时间步,都会去看一下编码器的全文,挑选当前时间步最“感兴趣”的部分,给予大的权重,基于“感兴趣”的部分,得出自己的预测

Attention in equations

假设编码器的hidden states分别为 h_1, ..., h_N

在解码器的第一个时间步,解码器的hidden state为 s_1

我们通过点乘得到当前时间步的attention scores:

e^1 = [s^T_1h_1,...,s^T_1h_N]

再通过softmax转换成attention distribution(是一个概率分布,和为1):

\alpha^1 = softmax(e^1)

然后用这个概率分布去和编码器的所有hidden states加权求和,得到attention output:

a_1 = \sum_{}^{}\alpha_i^1h_i

最后把attention output和解码器的hidden state合并起来得到 [a_1;s_1] ,然后去执行不用attention机制的seq2seq一样的做法就好了

这样就得到了解码器的第一个时间步的输出,后面以此类推


编辑于 2019-12-25

文章被以下专栏收录

    PaperWeekly是一个推荐、解读、讨论和报道人工智能前沿论文成果的学术平台,致力于让国内外优秀科研工作得到更为广泛的传播和认可。