原创:用 RNN 就足以精确建模语言!如何将 transformer 提速 n 倍,且性能更佳

原创:用 RNN 就足以精确建模语言!如何将 transformer 提速 n 倍,且性能更佳

你还在考虑线性注意力吗?我发现,你只需一个 RNN,只需看前一个字的状态,效果就能媲美 transformer。

我的 RWKV-v2-RNN 同时有 transformer 和 RNN 的优点。既可并行训练,又可串行训练。既可并行运行,又可串行运行。

在做单向序列建模的知友,欢迎联系我。我帮你调起来,让你试试效果(双向其实也可以,正走一遍,反走一遍)。

为什么普通 RNN 难并行?是因为在逐字计算【衰减速度】。

而 RWKV-v2-RNN 的每个通道,都有固定的可训练的衰减速度。

训练后会发现,底部层的平均衰减更快,对应短程信息;顶部层的平均衰减更慢,对应长程信息。

模型将信息从【快衰减通道】投射到【慢衰减通道】,即可实现记忆。

模型将信息从【慢衰减通道】投射到【快衰减通道】,即可实现遗忘。

模型在 CPU 也能飞速运行,每个字都写得一样快。我把一个 6 层小模型做成了网页版,在你的手机就可以直接写小说(其实 24 层都很快,主要问题是国内从 github 下载权重会太慢):

测试 27M 参数模型(6层,512维嵌入),就足以在 enwik8 获得 dev BPC 0.72!你可以下载并验证这个成绩。

它采用我的 RWKV v2 RNN 模型。从初始化到运行细节,包含大量新技巧,现全部公开:

训练和运行代码,都在项目的 RWKV v2 RNN 文件夹中。训练代码会自动编译一个 CUDA Kernel。训练好的 27M enwik8 模型在 github.com/BlinkDL/RWKV

我正在 the Pile 炼 RWKV-v2-RNN。等一个月,就知道是否能挑翻 transformer 了。只要能 match performance 就够,因为 RNN 的时间和空间复杂度,比 transformer 强太多。欢迎围观:知乎 - 安全中心

伪代码如图:

简单介绍:

  • a 和 b 是 kv 和 k 的 EMA(exponential moving average)。
  • c 和 d 是 a 和 b 加上 self-attention 效应(真·自注意力,原地注意力)。
  • c / d 是记忆机制,因为如果某个字在某个通道的 k 很强,且 W 接近 1,那么这个字就会被后文记住。
  • 这个 abcd 和 RKV 设计是类似苹果的 Attention-free Transformer。其实很多线性注意力都可以变成 RNN,例如 Linear Transformers 的研究。
  • T K V R W X P 是可以训练的参数,它们的尺寸也写在表中了。
  • W的初始化是看到Alibi编码才去试这种形式,发现确实不错。我更精确地设置了不同层用不同初始化,前期层负责短程,后期层负责长程。
  • 严格说,模型还有个 HeadQK 机制,这个会看一遍前文,不过很快。它可以让模型从前文复制或避免某些字。写小说需要拷贝人名,所以我想到了这个方法。后来看到苏剑林的 GlobalPointer 也有类似的想法。
  • ReluSquare 来自 Primer,我测试效果确实不错。

同时我使用了非常精细的初始化,对于中英文都效果很好。因此,可以直接使用 PostLN 架构,而且训练速度极快。大家自己研究吧。

欢迎大家在音乐和音频和语音和图像也试试这个。模型还可以继续改进,我写了几条在上图的 TODO 里面,欢迎大家水论文。

编辑于 2022-04-20 15:33