RNN 中学习长期依赖的三种机制

RNN 中学习长期依赖的三种机制

本文是个人长期以来积累的一些模糊的想法,基本上不涉及具体公式,有些部分可能没有确凿的证据证实,仅仅是个人观点而已,抛出来供各位一起思考。


理论上 Simple RNN(或者叫 Vanilla RNN)就已经是图灵完备的了,能够模拟任何程序。但是实践中总是不尽如人意,因为存在性(存在一组模型设置能达到某个目的)和可学习性(真的通过训练能够找到这组参数设置)是两回事。


目前大约有三类机制解决长期依赖的学习问题,分别是门机制、跨尺度连接和特殊初始化(及其维持)


【门机制】

  • 代表作
  • 解释
    • 这类 RNN Cell 大家应该很熟悉了,其主要特点是用门控制信息流动,隐层状态采用加性更新,不做非线性变换。具体的更新公式就不写了,下面只写一些理解。
    • 可以参考这篇文章 Written Memories: Understanding, Deriving and Extending the LSTM,以下内容也有很大一部分也源于这里。这里也简单导读一下这篇文章:
      • 这篇文章讲了 LSTM 设计的初衷和原则、然后根据这些原则推导出了 GRU 的设计。我先前一直觉得 LSTM 的门设计的很自然,而 GRU 是一种对 LSTM 的很奇怪的简化。而上面这篇文章成功地说服了我 GRU 的设计很合理,觉得 GRU 的门设计很奇怪的同学不要错过这篇文章。
      • 这篇文章认为 LSTM 里的读写顺序和状态分裂为 (c, h) 等设计很奇怪(文中叫 LSTM hiccup)。事实上,这也确实给初学者理解和编程实现带来了很多麻烦,例如 TF 中专门有一个类叫 LSTMStateTuple。不过,也可以从别的视角来解释和完善这种设计,参见 从信息隐匿的角度谈 LSTM:从 Stack 到 Nest
    • 理论上,各个门的值应该在 [0, 1] 之间。但是如果你真正训过一些表现良好的网络并且查看过门的值,就会发现很多时候门的值都是非常接近 0 或者 1 的,而类似于 0.2/0.5 这样的中间值很少。从直觉上我们希望门是二元的,控制信息流动的通和断,事实上训练良好的门也确实能达到这种效果。
    • 加入门机制可以解决普通 RNN 的梯度消失的问题,网络上相关的文章很多,这里就不仔细推了。
    • 更重要的是,门可以控制信息变形(information morphing)和选择性(selectivity)
      • 这一点可以这么思考:LSTM 1997 年刚提出的时候没有 forget gate,各个步骤的 c 之间的导数就有一个完美的单位阵,号称 Constant Error Carousel。后来 2000 年 Gers 给它加入了 forget gate,假如梯度消失是网络难以训练的最重要的原因,那为什么加入 forget gate 有可能导致梯度被截断,反而模型效果更好了?
      • 选择性体现在,想让信息流动的时候的就让它流动,不想让它流动的时候就关掉。例如做情感分析时,只让有情感极性的词和关联词等信息输入进网络,把别的忽略掉。这样一来,网络需要记忆的内容更少,自然也更容易记住。同样以 LSTM 为例,如果某个时刻 forget gate 是 0,虽然把网络的记忆清空了、回传的梯度也截断掉了,但这是 feature,不是 bug。
        • 这里举一个需要选择性的任务:给定一个序列,前面的字符都是英文字符,最后以三个下划线结束(例如 "abcdefg___");要求模型每次读入一个字符,在读入英文字符时输出下划线,遇到下划线后输出它遇到的前三个字符(对上面的例子,输出应该是 "_______abc")。显然,为了完成这个任务,模型需要学会记数(数到 3),只读入前三个英文字符,中间的字符都忽略掉,最终遇到 _ 时再输出它所记住的三个字符。“只读前三个字符”体现的就是选择性。
      • 信息变形体现在,模型状态在跨时间步时不存在非线性变换,而是加性的。假如普通 RNN 的状态里存了某个信息,经过多个时间步以后多次非线性变换把信息变得面目全非了,即使这个信息模型仍然记得,但是也读取不出来它当时到底存的是什么了。而引入门机制的 RNN 单元在各个时间步是乘上一些 0/1 掩码再加新信息,没有非线性变换,这就导致网络想记住的内容(对应掩码为 1)过多个时间步记得还是很清楚。


【跨尺度连接】

  • 代表作
    • CW-RNN: Clockwise RNN
      • 大意是普通 RNN 都是隐层从前一个时间步连接到当前时间步。而 CW-RNN 把隐层分成很多组,每组有不同的循环周期,有的周期是 1(和普通 RNN 一样),有的周期更长(例如从前两个时间步连接到当前时间步,不同周期的 cell 之间也有一些连接。这样一来,距离较远的某个依赖关系就可以通过周期较长的 cell 少数几次循环访问到,从而网络层数不太深,更容易学到。
    • Dilated RNN
      • 和 CW-RNN 类似,只是 CW-RNN 是在同一个隐层内部分组,Dilated RNN 是在不同的层分组:最下面的隐层每个时间步都循环,较高的隐层循环周期更长些,从而有效感受野更大。
    • NARX RNN: Nonlinear Auto-Regressive eXogenous RNN
      • 详见 Learning Long-Term Dependencies in NARX Recurrent Neural Networks
      • 类此以上两种方法,但前面的方法是把隐层单元分组,有的单元单步循环,有的多步循环;而这种方法是让每个隐层单元都在不同尺度上循环,例如某个隐层状态直接依赖于它的前一个、前两个、直到前 n-1 个隐层状态
      • 如果说普通的 RNN 是一阶递推式,这种就是 n 阶递推式
    • TKRNN: Temporal Kernel RNN
  • 解释
    • 既然学习长期依赖很难,那就手动把依赖的步数缩短,然后学习短期依赖就可以了。思想有点儿类似于 ResNet 中的 skip-connection(但是跨越多个时间步的连接用的不是单位阵,而是需要学习的稠密矩阵),使得模型输出层可以看到之前不同时间步的信息,进而达到类似模型集成的效果。


【特殊初始化(及其维持)】

  • 代表作
    • IRNN: Identity Recurrent Neural Networks
      • 参考 [1504.00941] A Simple Way to Initialize Recurrent Networks of Rectified Linear Units
      • 大意就是普通的 RNN,把激活函数换成 ReLU,隐层的自循环矩阵用单位阵 I 初始化、偏置项设置成 0,然后在语言模型等任务上效果跟 LSTM 差不多。
      • 恒等初始化 + ReLU 的效果是让模型一开始先记住所有信息,并且能不变形地复制到下一时间步(暂时不考虑负元素被置为零的问题),之后模型再慢慢学习如何平衡长短距离的依赖信息。
      • LSTM 在 1997 年刚出来的时候没有遗忘门(相当于遗忘门恒为 1),其实也是这种设计思想
      • 同时作者也说了,如果该任务不需要长期依赖,把自循环矩阵初始化成 \alpha I (\alpha \in [0, 1)) 更好,有助于模型快速忘掉长期信息。
    • np-RNN: normalized-positive definite RNN
      • 详见 [1511.03771] Improving performance of recurrent neural network with relu nonlinearity
      • IRNN 的续作。IRNN 效果也不错,但是对超参数的选取较为敏感。这篇文章从 dynamic system 的角度分析了 ReLU RNN 的 fixed point (就是有些隐层状态会一直卡在同一个地方迭代后几乎不动)的演化行为,提出新的初始化方法,比 IRNN 更容易训练成功
      • 具体的做法是将 ReLU RNN 的自循环矩阵初始化成最大特征值为 1 的半正定矩阵。根据文章里的分析,这种设置可以尽量避免隐层状态退化到低维的 stable fixed point 上,从而有助于训练过程。
    • RIN: Recurrent Identity Network
    • Unitary-RNN
      • 详见 [1511.06464] Unitary Evolution Recurrent Neural Networks
      • 正交矩阵有个好处是特征值都是 1,连乘很多次也不会爆炸或者消失;很多框架默认的 RNN (包括 LSTM 等现代 RNN 单元)初始化方法都包含正交初始化,也是这一思路的体现;当然现在也有很多人直接使用 uniform distribution 初始化,不用正交初始化了。
      • Unitary-RNN 在正交化这一点上更是一条路走到黑,初始化时使用酉矩阵(正交矩阵的复数版本,它的权重参数是复数而不是实数),然后通过特殊的参数化技巧把权重矩阵表达成一个形式上较复杂但计算上很高效的形式,使得参数更新后也能保证新的参数矩阵仍然是酉矩阵,从而训练过程中一直都不会发生梯度消失或爆炸。
      • 其实 Unitary-RNN 这一系列复杂操作的目的仅仅是高效地维持矩阵的正交性。最简单的实现这一目的的方法是梯度投影法:先梯度下降,然后把参数矩阵投影成正交阵,只是这么做效率太低。
      • 文章里的实验在 copying memory problem 上效果拔群,远超 LSTM。不过这是一个刻意构造的 toy task 了。
    • IndRNN:
      • 详见 Independently Recurrent Neural Network (IndRNN): Building A Longer and Deeper RNN
      • 大意是把自循环矩阵设置为对角阵,这样就把隐层的各个维度分离开了,导数的形式就从矩阵连乘变成了标量连乘,更容易分析
      • 但是普通的标量连乘还是会导致梯度消失/爆炸,所以 IndRNN 在训练的时候会对权重做裁剪(是 weight clipping,而不是 gradient clipping),把权重强行控制在某个范围,就可以使得它自乘若干次以后不至于太大或者太小
      • 吐个槽,这种做法感觉比 gradient clipping 还粗暴,居然还能奏效……
    • singular value clipping
      • 详见 Preventing Gradient Explosions in Gated Recurrent Units
      • 大意是说 RNN 的梯度爆炸是由非线性系统的 bifurcation point 引起的。RNN 的 fixed point (即 RNN 运行一步后隐状态保持不变的点)有些是稳定的,有些是不稳定的,在 bifurcation point 附近梯度会暴增产生梯度爆炸。文章在某种特殊情形下找到了一个 stable fixed point,然后通过在训练过程中裁剪参数矩阵的 singular value,使得参数矩阵一直处于他找到的特殊情形下,从而参数永远在比较稳定的区域内更新,不会发生梯度爆炸,训练过程更稳定。
      • 注意,即便是 GRU/LSTM 等现代 RNN 单元也只能解决梯度消失,不能解决梯度爆炸。很多人都以为 GRU/LSTM 可以解决梯度爆炸问题,这是一个常见误解,详情请参考 Why can RNNs with LSTM units also suffer from "exploding gradients"?),梯度爆炸现在主要靠 gradient clipping 来解决。
      • 最有趣的地方是,这篇文章和上面的 np-RNN 都基于对 dynamic system 和 fixed point 的分析,但是出发点却完全相反:np-RNN 想让模型离开 fixed point,而这篇文章想让模型别走出 fixed point 附近。我个人对于 dynamic system 不太熟悉,不过 fixed point 可能是 RNN 有时会出现连续重复预测现象的原因。到底应该避开 fixed point 还是应该利用它可能还需要进一步的研究。
    • SCRNN:Structurally Constrained Recurrent Neural Network
      • 大意是把隐层分成两部分,其中一部分的更新方法是每次自乘 \alpha\in[0, 1) 再加该步的新信息,被叫做 context feature;另一部分则类似于普通 RNN,根据前一时间步的值和当前时间步的 context feature 来更新。
      • 这里的 context feature 的更新也是加性的,不存在非线性变换。\alpha 的值接近 1(例如 0.95),近似于恒等映射,让保存的历史信息衰减得慢一些。
      • 这里将 RNN 的隐层分裂为两部分,其实也类似于 LSTM 中 (c, h) 分裂的操作。其中 context feature 类似于 LSTM 的 c,储层长期信息。
  • 解释
    • 特殊初始化(及其维持)大致有三种方法:(近似)恒等映射、正交化、参数范围控制。其中(近似)恒等映射( I\alpha I )以 IRNN 等为代表,可以被看成是一种简化版的遗忘门;正交化以 Unitary-RNN 为代表;参数范围控制以 IndRNN (weight clipping)/gradient clipping/singular value clipping 为代表。
    • 对于(近似)恒等映射而言,这种方案首先让模型记住所有历史信息,然后在学习过程中慢慢恢复模型的表达能力,其实是牺牲了模型短期内拟合复杂函数的能力。设想,假如数据的分布是随着时间变化的,门机制可以更快地对新数据带来的改变做出适应,而这种方案可能就凉了。
    • 除了在 RNN 中,CNN 也可以利用类似的技巧。例如这篇文章( If resnets are the answer, then what is the question? )里提出 LL-Init(looks linear initialization)的技巧,可以在不加 skip-connection 的情况下训练 200 层深的网络,办法就是先把网络初始化成一个线性函数,然后让它慢慢恢复自己的复杂结构的拟合能力。另外这篇论文对 ResNet 所解决的问题论证得很漂亮(很多人以为 ResNet 是解决梯度弥散问题,但并不是!因为 BatchNorm 就足够解决梯度弥散了,ResNet 解决的显然是另外的问题),值得一读。此外还有一篇 DiracNet (DIRACNETS: TRAINING VERY DEEP NEURAL NETWORKS WITHOUT SKIP-CONNECTIONS)和该文相似,不过实验更扎实一些。

此外,还有一些方法是上述三种思想结合的产物,例如:

  • Statistical Recurrent Unit,详见 [1703.00381] The Statistical Recurrent Unit
    • 大意是认为模型知道所有历史信息的话对预测会有帮助,所以采用 moving average 对序列每一步得到的一些统计量做更新,然后模型以此为基础得到输出。使用的主要机制是近似恒等映射:论文中的公式 (7) 相当于把 LSTM 遗忘门固定成 \alpha
    • 在我看来这个模型其实和 LSTM 差不多:SRU 中的 statistics (\mu) 对应 LSTM 的 cell state (c), SRU 中的 \varphi 对应 LSTM 的 new candidate state (\tilde{c}) ,SRU 中的 o 对应 LSTM 的 h。但是作者讲故事的功底很好,不得不服= =
    • 由于 \alpha<1,因此 \mu 会随着时间指数衰减;而不同的 \alpha 的值对应的衰减速度不同,使用多个不同的 \alpha 就能使得输出 o 看到不同衰减速度的统计信息,这里相当于跨尺度连接
  • Fourier Recurrent Unit,详见 Learning Long Term Dependencies via Fourier Recurrent Units
    • 本文从上面的 Statistical Recurrent Unit 发展而来,改进了“遗忘门” \alpha 呈指数衰减的缺点。具体做法是将统计信息 u 的更新公式改为 u^{(t)} = u^{(t-1)} + \dfrac{1}{T} C^{(t)} \cdot h^{(t)} ,其中 C^{(t)} 是一个包含不同频率余弦系数的矩阵。这里去掉了衰减因子 \alpha ,使用了恒等映射。
    • 最终的效果是,如果把 u 沿时间轴完全展开,可以发现不会再像 exponential moving average 那样重短期信息、轻长期信息,而是在不同余弦频率下起起伏伏,总有合适的频率能够保存到很多时间步之前的信息。这相当于是一个动态权重的、无限长时间步的多尺度连接,实在是厉害。
    • 文章证明了 FRU 的梯度是有上下界 bound 住的,梯度大小和时间步数无关,因此没有梯度消失和梯度爆炸问题,这一点非常厉害,是其他模型所不具备的。



三种方案当然各有优劣,其中门的好处自然是选择性,并且是现在比较流行的方案,对于各种任务效果都比较 robust。而后两种方案网络结构更简单,可能更适应于某些特殊的计算平台,或者是对于不太需要选择性的任务可能会有很好的效果,如以下回答中所描述的那样:

RNN中为什么要采用tanh而不是ReLu作为激活函数?www.zhihu.com图标

这个回答最后的总结 “『如果把ReLU RNN的参数增加四倍到跟LSTM的参数一样多,它应该是会稳定好过LSTM的』这个应该限定到我熟悉的语音识别任务,对NLP等等可能不大对。”其实说的也就是我在这里论述的选择性的问题:毕竟语音识别所需要的选择性不太强(和情感分析等相比)。

这里,特别要感谢 @chaopig@saizheng 在以上答案中的讨论,以及 @chaopig 补充的 QRNN, NARX RNN 和 TKRNN。

编辑于 2018-06-19

文章被以下专栏收录