首发于谓之小一

LSTM如何解决RNN带来的梯度消失问题

本篇文章参考于 RNN梯度消失和爆炸的原因Towser关于LSTM如何来避免梯度弥散和梯度爆炸?的问题解答、Why LSTMs Stop Your Gradients From Vanishing: A View from the Backwards Pass
看本篇文章之前,建议自行学习RNN和LSTM的前向和反向传播过程,学习教程可参考刘建平老师博客循环神经网络(RNN)模型与前向反向传播算法LSTM模型与前向反向传播算法

具体了解LSTM如何解决RNN所带来的梯度消失问题之前,我们需要明白为什么RNN会带来梯度消失问题。

1. RNN梯度消失原因

如上图所示,为RNN模型结构,前向传播过程包括,

  • 隐藏状态: h^{(t)} = \sigma (z^{(t)}) = \sigma(Ux^{(t)} + Wh^{(t-1)} + b) ,此处激活函数一般为 tanh
  • 模型输出: o^{(t)} = Vh^{(t)} + c
  • 预测输出: \hat{y}^{(t)} = \sigma(o^{(t)}) ,此处激活函数一般为softmax。
  • 模型损失: L = \sum_{t = 1}^{T} L^{(t)}

RNN反向传播过程中,需要计算 U, V, W 等参数的梯度,以 W 的梯度表达式为例,

 \frac{\partial L}{\partial W} = \sum_{t = 1}^{T} \frac{\partial L}{\partial y^{(T)}} \frac{\partial y^{(T)}}{\partial o^{(T)}} \frac{\partial o^{(T)}}{\partial h^{(T)}} \frac{\partial h^{(T)}}{\partial h^{(t)}} \frac{\partial h^{(t)}}{\partial W} \\

现在需要重点计算 \frac{\partial h^{(T)}}{\partial h^{(t)}} 部分,展开得到,  \frac{\partial h^{(T)}}{\partial h^{(t)}} = \frac{\partial h^{(T)}}{\partial h^{(T-1)}} \frac{\partial h^{(T - 1)}}{\partial h^{(T-2)}} ...\frac{\partial h^{(t+1)}}{\partial h^{(t)}} = \prod_{k=t + 1}^{T} \frac{\partial h^{(k)}}{\partial h^{(k - 1)}} = \prod_{k=t+1}^{T} tanh^{'}(z^{(k)}) W \\

那么 W 的梯度表达式也就是,

\frac{\partial L}{\partial W} = \sum_{t = 1}^{T} \frac{\partial L}{\partial y^{(T)}} \frac{\partial y^{(T)}}{\partial o^{(T)}} \frac{\partial o^{(T)}}{\partial h^{(T)}} \left( \prod_{k=t + 1}^{T} \frac{\partial h^{(k)}}{\partial h^{(k - 1)}} \right) \frac{\partial h^{(t)}}{\partial W} \  \\ = \sum_{t = 1}^{T} \frac{\partial L}{\partial y^{(T)}} \frac{\partial y^{(T)}}{\partial o^{(T)}} \frac{\partial o^{(T)}}{\partial h^{(T)}} \left( \prod_{k=t+1}^{T} tanh^{'}(z^{(k)}) W \right) \frac{\partial h^{(t)}}{\partial W} \  \\

其中 tanh^{'}(z^{(k)}) = diag(1-(z^{(k)})^2) \leq 1 ,随着梯度的传导,如果 W 的主特征值小于1,梯度便会消失,如果W的特征值大于1,梯度便会爆炸。

需要注意的是,RNN和DNN梯度消失和梯度爆炸含义并不相同。RNN中权重在各时间步内共享,最终的梯度是各个时间步的梯度和。因此,RNN中总的梯度是不会消失的,即使梯度越传越弱,也只是远距离的梯度消失。 RNN所谓梯度消失的真正含义是,梯度被近距离梯度主导,远距离梯度很小,导致模型难以学到远距离的信息。 明白了RNN梯度消失的原因之后,我们看LSTM如何解决问题的呢?

2. LSTM为什么有效?


如上图所示,为RNN门控结构,前向传播过程包括,

  • 遗忘门输出: f^{(t)} = \sigma(W_fh^{(t-1)} + U_fx^{(t)} + b_f)
  • 输入门输出: i^{(t)} = \sigma(W_ih^{(t-1)} + U_ix^{(t)} + b_i) , a^{(t)} = tanh(W_ah^{(t-1)} + U_ax^{(t)} + b_a)
  • 细胞状态: C^{(t)} = C^{(t-1)}\odot f^{(t)} + i^{(t)}\odot a^{(t)}
  • 输出门输出: o^{(t)} = \sigma(W_oh^{(t-1)} + U_ox^{(t)} + b_o) , h^{(t)} = o^{(t)}\odot tanh(C^{(t)})
  • 预测输出: \hat{y}^{(t)} = \sigma(Vh^{(t)}+c)

RNN梯度消失的原因是,随着梯度的传导,梯度被近距离梯度主导,模型难以学习到远距离的信息。具体原因也就是 \prod_{k=t+1}^{T}\frac{\partial h^{(k)}}{\partial h^{(k - 1)}} 部分,在迭代过程中,每一步 \frac{\partial h^{(k)}}{\partial h^{(k - 1)}} 始终在[0,1]之间或者始终大于1。

而对于LSTM模型而言,针对 \frac{\partial C^{(k)}}{\partial C^{(k-1)}} 求得,

\frac{\partial C^{(k)}}{\partial C^{(k-1)}} = \frac{\partial C^{(k)}}{\partial f^{(k)}} \frac{\partial f^{(k)}}{\partial h^{(k-1)}} \frac{\partial h^{(k-1)}}{\partial C^{(k-1)}} + \frac{\partial C^{(k)}}{\partial i^{(k)}} \frac{\partial i^{(k)}}{\partial h^{(k-1)}} \frac{\partial h^{(k-1)}}{\partial C^{(k-1)}}  \\ + \frac{\partial C^{(k)}}{\partial a^{(k)}} \frac{\partial a^{(k)}}{\partial h^{(k-1)}} \frac{\partial h^{(k-1)}}{\partial C^{(k-1)}} + \frac{\partial C^{(k)}}{\partial C^{(k-1)}}  \\

具体计算后得到,

\frac{\partial C^{(k)}}{\partial C^{(k-1)}} = C^{(k-1)}\sigma^{'}(\cdot)W_fo^{(k-1)}tanh^{'}(C^{(k-1)})  \\ + a^{(k)}\sigma^{'}(\cdot)W_io^{(k-1)}tanh^{'}(C^{(k-1)})  \\  + i^{(k)}tanh^{'}(\cdot)W_c*o^{(k-1)}tanh^{'}(C^{(k-1)})  \\ + f^{(t)}  \\

 \prod _{k=t+1}^{T} \frac{\partial C^{(k)}}{\partial C^{(k-1)}} = (f^{(k)}f^{(k+1)}...f^{(T)}) + other  \\

在LSTM迭代过程中,针对 \prod_{k=t+1}^{T} \frac{\partial C^{(k)}}{\partial C^{(k-1)}} 而言,每一步\frac{\partial C^{(k)}}{\partial C^{(k-1)}} 可以自主的选择在[0,1]之间,或者大于1,因为 f^{(k)} 是可训练学习的。那么整体 \prod _{k=t+1}^{T} \frac{\partial C^{(k)}}{\partial C^{(k-1)}} 也就不会一直减小,远距离梯度不至于完全消失,也就能够解决RNN中存在的梯度消失问题。LSTM虽然能够解决梯度消失问题,但并不能够避免梯度爆炸问题,仍有可能发生梯度爆炸。但是,由于LSTM众多门控结构,和普通RNN相比,LSTM发生梯度爆炸的频率要低很多。梯度爆炸可通过梯度裁剪解决。

LSTM遗忘门值可以选择在[0,1]之间,让LSTM来改善梯度消失的情况。也可以选择接近1,让遗忘门饱和,此时远距离信息梯度不消失。也可以选择接近0,此时模型是故意阻断梯度流,遗忘之前信息。更深刻理解可参考LSTM如何来避免梯度弥散和梯度爆炸?中回答。

编辑于 04-25

文章被以下专栏收录