Tensorflow源码解读(一):Attention
Seq2Seq模型

Tensorflow源码解读(一):Attention Seq2Seq模型

Tensorflow版本:r0.12
github源码
ps: 0.12和1.0,1.1, 1.2版本的代码基本一致,本文对更高版本也具有参考价值。

Seq2Seq模型是机器翻译,对话生成等任务里经典的模型,attention机制也是在2016年刷爆了各种NLP任务,这两者都是很值得深入研究掌握的模型。本文要分享的是Tensorflow官方例子,翻译模型里的embedding_attention_seq2seq函数源码解读。文章参考了另一篇博客[1]和官方github源码,attention部分的公式和推导涉及了源码参考的论文[2]。

tf.nn.seq2seq文件共实现了5个seq2seq函数,因为本文重点讲解最后一个,所以前4个简要介绍一下。

  • basic_rnn_seq2seq:最简单版本,输入和输出都是embedding的形式;最后一步的state vector作为decoder的initial state;encoder和decoder用相同的RNN cell, 但不共享权值参数;
  • tied_rnn_seq2seq:同1,但是encoder和decoder共享权值参数
  • embedding_rnn_seq2seq:同1,但输入和输出改为id的形式,函数会在内部创建分别用于encoder和decoder的embedding matrix
  • embedding_tied_rnn_seq2seq:同2,但输入和输出改为id形式,函数会在内部创建分别用于encoder和decoder的embedding matrix
  • embedding_attention_seq2seq:同3,但多了attention机制

下面进入正题!

tf.nn.seq2seq.embedding_attention_seq2seq

# T代表time_steps, 时序长度
def embedding_attention_seq2seq(encoder_inputs,  # [T, batch_size] 
                                decoder_inputs,  # [T, batch_size]
                                cell,
                                num_encoder_symbols,
                                num_decoder_symbols,
                                embedding_size,
                                num_heads=1,      # attention的抽头数量
                                output_projection=None, #decoder的投影矩阵
                                feed_previous=False,
                                dtype=None,
                                scope=None,
                                initial_state_attention=False):

参数

Input

  • encoder_inputs:encoder的输入,int32型 id tensor list
  • decoder_inputs:decoder的输入,int32型id tensor list
  • cell: RNN_Cell的实例
  • num_encoder_symbols, num_decoder_symbols: 分别是编码和解码的符号数,即词表大小
  • embedding_size: 词向量的维度
  • num_heads:attention的抽头数量,一个抽头算一种加权求和方式,后面会进一步介绍
  • output_projection:decoder的output向量投影到词表空间时,用到的投影矩阵和偏置项(W, B);W的shape是[output_size, num_decoder_symbols],B的shape是[num_decoder_symbols];若此参数存在且feed_previous=True,上一个decoder的输出先乘W再加上B作为下一个decoder的输入
  • feed_previous:若为True, 只有第一个decoder的输入(“GO"符号)有用,所有的decoder输入都依赖于上一步的输出;一般在测试时用(当然源码也提到,可以在训练时用于模拟测试的环境,比如Scheduled Sampling
  • initial_state_attention: 默认为False, 初始的attention是零;若为True,将从initial state和attention states开始attention

Output

  • (outputs, state) tuple pair,outputs是 2D Tensors list, 每个Tensor的shape是[batch_size, cell.state_size];state是 最后一个时间步,decoder cell的state,shape是[batch_size, cell.state_size]

Encoder

  • 创建了一个embedding matrix.
  • 计算encoder的output和state
  • 生成attention states,用于计算attention
encoder_cell = rnn_cell.EmbeddingWrapper(      
        cell, embedding_classes=num_encoder_symbols,
        embedding_size=embedding_size)
    encoder_outputs, encoder_state = rnn.rnn(
        encoder_cell, encoder_inputs, dtype=dtype) #  [T,batch_size,size]

    top_states = [array_ops.reshape(e, [-1, 1, cell.output_size])
                  for e in encoder_outputs]    # T * [batch_size, 1, size]
    attention_states = array_ops.concat(1, top_states) # [batch_size,T,size]

上面的EmbeddingWrapper, 是RNNCell的前面加一层embedding,作为encoder_cell, input就可以是word的id。

class EmbeddingWrapper(RNNCell):
  def __init__(self, cell, embedding_classes, embedding_size, initializer=None):
  def __call__(self, inputs, state, scope=None):
  #生成embedding矩阵[embedding_classes,embedding_size]
  #inputs: [batch_size, 1]
  #return : (output, state)

Decoder

  • 生成decoder的cell,通过OutputProjectionWrapper类对输入参数中的cell实例包装实现
# Decoder.
    output_size = None
    if output_projection is None:
      cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)
      output_size = num_decoder_symbols
    if isinstance(feed_previous, bool):
      return embedding_attention_decoder(
          ...
      )

上面的OutputProjectionWrapper将输出映射成想要的维度

class OutputProjectionWrapper(RNNCell):
  def __init__(self, cell, output_size): # output_size:映射后的size
  def __call__(self, inputs, state, scope=None):
  #init 返回一个带output projection的 rnn_cell

接着对embedding_attention_decoder一探究竟:

def embedding_attention_decoder(decoder_inputs,
                                initial_state,
                                attention_states,
                                cell,
                                num_symbols,
                                embedding_size,
                                num_heads=1,
                                output_size=None,
                                output_projection=None,
                                feed_previous=False,
                                update_embedding_for_previous=True,
                                dtype=None,
                                scope=None,
                                initial_state_attention=False):
# 核心代码
    embedding = variable_scope.get_variable("embedding",
                                            [num_symbols, embedding_size])
    loop_function = _extract_argmax_and_embed(
        embedding, output_projection,
        update_embedding_for_previous) if feed_previous else None
    emb_inp = [
        embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs]
    # T * [batch_size, embedding_size]
    return attention_decoder(
        emb_inp,
        initial_state,
        attention_states,
        cell,
        output_size=output_size,
        num_heads=num_heads,
        loop_function=loop_function,
        initial_state_attention=initial_state_attention)

简要的说,embedding_attention_decoder的代码,第一步创建了解码用的embedding; 第二步创建了一个循环函数loop_function,用于将上一步的输出映射到词表空间,输出一个word embedding作为下一步的输入;最后是我们最关注的attention_decoder部分完成解码工作!

tf.nn.attention_decoder

论文涉及三个公式:

encoder输出的隐层状态(h_{1},...,h_{T_{A}}),decoder的隐层状态(d_{1},...,d_{T_{B}})v^{T}W^{'}_{1}W^{'}_{2}是模型要学的参数。所谓的attention,就是在每个解码的时间步,对encoder的隐层状态进行加权求和,针对不同信息进行不同程度的注意力。那么我们的重点就是求出不同隐层状态对应的权重。源码中的attention机制里是最常见的一种,可以分为三步走:(1)通过当前隐层状态(d_{t})和关注的隐层状态(h_{i})求出对应权重u^{t}_{i};(2)softmax归一化为概率;(3)作为加权系数对不同隐层状态求和,得到一个的信息向量d^{'}_{t}。后续的d^{'}_{t}使用会因为具体任务有所差别。

上面的a^{t}_{i}含义是第t个时间步,对h_{i}的加权系数。

下面上代码的时刻!

def attention_decoder(decoder_inputs,  #T * [batch_size, input_size]
                      initial_state,   #[batch_size, cell.states]
                      attention_states,#[batch_size, attn_length , attn_size]
                      cell,
                      output_size=None,
                      num_heads=1,
                      loop_function=None,
                      dtype=None,
                      scope=None,
                      initial_state_attention=False):

对于num_heads参数,还记得当初留的坑么:) 我们知道,attention就是对信息的加权求和,一个attention head对应了一种加权求和方式,这个参数定义了用多少个attention head去加权求和,所以公式三可以进一步表述为\sum^{num\_heads}_{j=1}\sum^{T_{A}}_{i=1}a_{i,j}h_{i}

  • W_{1}*h_{i}用的是卷积的方式实现,返回的tensor的形状是[batch_size, attn_length, 1, attention_vec_size]
# To calculate W1 * h_t we use a 1-by-1 convolution
    hidden = array_ops.reshape(
        attention_states, [-1, attn_length, 1, attn_size])
    hidden_features = []
    v = []
    attention_vec_size = attn_size  # Size of query vectors for attention.
    for a in xrange(num_heads):
      k = variable_scope.get_variable("AttnW_%d" % a,
                                      [1, 1, attn_size, attention_vec_size])
      hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
      v.append(
          variable_scope.get_variable("AttnV_%d" % a, [attention_vec_size]))
  • W_{2}*d_{t},此项是通过下面的线性映射函数linear实现
for a in xrange(num_heads):
        with variable_scope.variable_scope("Attention_%d" % a):
          # query对应当前隐层状态d_t
          y = linear(query, attention_vec_size, True)
          y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
          # 计算u_t
          s = math_ops.reduce_sum(
              v[a] * math_ops.tanh(hidden_features[a] + y), [2, 3])
          a = nn_ops.softmax(s)
          # 计算 attention-weighted vector d.
          d = math_ops.reduce_sum(
              array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden,
              [1, 2])
          ds.append(array_ops.reshape(d, [-1, attn_size]))

到这里,embedding_attention_seq2seq的核心代码都已经解读完毕了。在实际的运用,可以根据需求灵活使用各个函数,特别是attention_decoder函数。相信坚持阅读下来的小伙伴们,能对这个API有更深刻的认识:)


参考文献:

[1] tensorflow学习笔记(十一):seq2seq Model相关接口介绍

[2] Grammar as a Foreign Language

编辑于 2017-07-13

文章被以下专栏收录