测试两种新 Attention 机制:gMLP 和 AFT(结论:AFT 效果好)

测试两种新 Attention 机制:gMLP 和 AFT(结论:AFT 效果好)

首先看公式(带 causal mask)。本文提出一个通用形式:

\textbf{ATT}_{t,c} = R_{t,c}\cdot \frac{\sum_{u \leq t}W_{t,u,c}\cdot V_{u,c}}{\sum_{u \leq t}W_{t,u,c}}

其中 R_{t,c} 是当前 t 位置在 c 通道接受信息的意愿, V_{u,c} 是 u 位置发送的信息内容, W_{t,u,c} 是 t 和 u 在 c 通道的关联强度(个人看法,目前的语言模型实际做的是,由旧字向新字发送信息,逐渐为新字消歧。W 是旧字和新字的关联强度)。

普通注意力。 W_{t,u,c} =\exp(Q_t\cdot K_u) 其中 Q dot K 在 c 对应的头中计算。需要额外 positional encoding 因为缺少 t 和 u 的距离信息。Q K V 由 linear(c, c) 生成。

\textbf{MHA}_{t,c} = \frac{\sum_{u \leq t}\exp(Q_t\cdot K_u)\cdot V_{u,c}}{\sum_{u \leq t}\exp(Q_t\cdot K_u)}

普通注意力,加明确的 time-weighting W_{t,u} (见 PENG Bo:只需几行代码,改进Transformer自注意力机制(几乎不增加计算量) )。可以不需要额外 positional encoding,因为已提供明确的 t 和 u 关联强度。

\textbf{MHATW}_{t,c} = \frac{\sum_{u \leq t} W_{t,u} \cdot \exp(Q_t\cdot K_u)\cdot V_{u,c}}{\sum_{u \leq t} W_{t,u} \cdot \exp(Q_t\cdot K_u)}

gMLP(Pay Attention to MLPs)。其中 R 和 V 由 gelu(linear(c, 3c)) 生成。它也使用了明确的 circulant W_{t,u} = f(t-u) 。同时它没有分母的归一化,因为 \sum_{u \leq t} W_{t,u} 的变化很简单,网络可在其它地方学会。由于它的公式很简化,所以需要至少 linear(c, 2c) 才能有好效果。用 linear(c, c) 会明显变差。所以注意下面的 c 维度较高。

\textbf{gMLP}_{t,c} = R_{t,c}\cdot \sum_{u \leq t} W_{t,u}\cdot V_{u,c}

AFT(An Attention Free Transformer)。其中 R 由 sigmoid(linear(c, c)) 生成,接受度在 0~1 范围。K 和 V 由 linear(c, c) 生成。同时,这里的 \exp(K_{u,c}) 可理解为 u 位置在 c 通道的发送强度,\exp 的优点是容易生成极强和极弱的强度,而且不会出现负强度,这些特点都适合语言模型。

\textbf{AFT}_{t,c} = R_{t,c}\cdot \frac{\sum_{u \leq t}W_{t,u}\cdot\exp(K_{u,c})\cdot V_{u,c}}{\sum_{u \leq t}W_{t,u}\cdot\exp(K_{u,c})}

注意 AFT 原始论文将 W 做低秩分解,但是 circulant matrix 本身是满秩的(它的 eigenvector 就是 fourier series),分解会损害对称性(不过,后文我们会看到 W 在大部分区域很均匀,所以用低秩分解效果也还可以)。用本文稍后的分解更好。

总结:

RWV
MHA1exp(Q*K)V
MHATW1W*exp(Q*K)V
gMLPgelu(R)Wgelu(V)
AFTsigmoid(R)W*exp(K)V

后文测试,AFT 在改进后的 速度 & 精度 & 参数量,都是其中较好的,可能是因为它在 R 和 W 都有足够的表达力。有算力的朋友可试试架构搜索,看有没有更好的 R W V 形式。


本文用更细致的 W,改进 MHATW 和 gMLP 和 AFT 的效果。

首先,原始的 circulant W 是:

W_{t,u,c}=f(t-u)

然后改成"多头",允许不同 c 使用不同的 f。例如将 c 划分为 8 个 h,那么可以有 8 种不同的 f。公式变成:

W_{t,u,c}=f_h(t-u)

训练出的 8 个不同 f 的例子:

最右边是最靠近 t 的情况,所以 f 的变化丰富。中间是离得远的情况,所以很均匀(所以 W 在大部分区域很均匀)。

最左边为什么有变化?是因为左边的字只有很短的前文,它在模型中的运作模式不同,W 对它做了特别处理。例如,最左边的那个字,无论经过多少层,都看不到任何前文,只是自己和自己作用。

由于模型有 context window 长度限制,不同 t 的 context window 长度不同(例如有些只能看见很短的前文),所以模型应加入额外的时间因子 \alpha\beta

W_{t,u,c}=f_h(t-u)\cdot \alpha_h(t) \cdot \beta_h(u)

最后,整体也可以乘一个 \gamma(t) 因子(因为 \alpha\beta 都在 W 中,它们会被归一化掉):

\textbf{ATT}_{t,c} = \gamma(t) \cdot R_{t,c}\cdot \frac{\sum_{u \leq t}W_{t,u,c}\cdot V_{u,c}}{\sum_{u \leq t}W_{t,u,c}}

如果把模型写成 RNN 的形式,恢复时间对称性,就可以去掉这些 \alpha \beta \gamma 因子。否则,加上去有助于在 context 很短时的生成。训练后会看到网络使用这些因子。


我们用单层模型(8头 512维 128长度)测试拟合力,参数量约 9e6。FFN 用 GeGLU。数据用 8646542 characters, 5744 unique 的金庸全集。优化用 AdaBelief 1e-3(小模型用大 lr 学得快)cosine decay to 1e-4。Batchsz 64。

gMLP 论文只用 gMLP 层,所以需要额外 LayerNorm,额外初始化方法,以及 linear(c, 3c)。这里有 FFN 就不需要这些了,用 linear(c, 2c) 可将速度和参数量调至类似其它模型的水平。另外这里的 gMLP 没有加额外 SA,所以可以看到表达力略弱。

全部模型都加了我之前写的 time-weighting 和 time-mixing。这里的 MHA+ 还加了 rotary embedding,加了点 Talking-Heads Attention。这些改进都有效果。

另外所有模型都不需要额外的 positional encoding。

Perplexity(越小越好)速度(越大越好)参数量(越小越好)
MHA+14.0712.69.3e6
gMLP14.4212.99.8e6
AFT13.5313.49.3e6

结论:改进的 AFT 是目前较好的选择。

在 context length 增大后,速度 AFT > gMLP > MHA+ 会有明显体现。


我也炼了 24 层(16头 1024维 512长度)的 AFT,用 40G 中文语料训练,生成效果合理,说明 AFT 也适合大规模语言模型。

建了个文本生成的交流 QQ 群 143626394 (加入时请简单自我介绍)。

之前用 MHA 炼的 12 层模型下载:

其中还提出一种新的采样方法,生成文本的观感比 top-p top-k 好。

编辑于 2021-08-02 16:24