首发于NLPCAB
结构剪枝:要个4层的BERT有多难?

结构剪枝:要个4层的BERT有多难?

更新:

感谢 @AlbertMaMa 同学,layerdrop已经有源码了(在fairseq中,要自己找一下):

https://github.com/pytorch/fairseq/tree/master/examples/layerdropgithub.com

想看transformer剪枝文章的起因是美团搜索团队的一篇文章[1],里面提到他们将BERT裁剪为四层,在query分类的业务场景下效果不降反升,简直香。但是文章里也没有透露具体的做法,我身边也没认识到美团搜索团队的同学,于是决定自己动手调研一波。

首先,剪枝的主要原因是提升推理速度,另外是希望去掉在个别case上过拟合的权重,提升模型泛化能力

我认为剪枝的研究点主要在以下两方面:

  1. 剪掉什么?
    这里可以分为weight pruning(WP)和structured pruning(SP)。WP指将矩阵中的某个权重连接置为0,处理完后通过稀疏矩阵进行运算,或者将某一行或一列置为0(神经元),仍用矩阵进行运算。但WP的缺点是对矩阵进行改变,影响运算效率,所以不进行更深入的调研了,有兴趣的同学可以参考[2]中提供的论文。SP则是对一组权重进行剪枝,比如去掉Attention heads、Attention mechanism、FFN层或者整个Transformer层,如果效果不下降的话对工程十分友好,因此下文的剪枝都指结构性剪枝。
  2. 什么时候剪?
    这里可以分为预训练、训练/精调、测试三个阶段,是应该在训练/预训练/精调的时候去学习如何剪,还是预训练之后剪掉一些再精调,还是直接在预测的时候剪呢?各位可以自己先想想~

方案一:预测时剪掉Attention Heads

在一篇NIPS2019[3]中,作者等人主要研究了注意力头对结果的影响,评估的任务主要是encoder-decoder的NMT和BERT的MNLI。在预测阶段mask掉head进行研究后可以发现:

在翻译任务中,只有8(/96)个头会对预测结果有显著影响,大多数头在预测时都是多余的。

上图中,作者只保留了当前层中最好的一个head。可以看出BERT中大部分层在只有一个head的情况下,对结果的影响都不大。因此通过去掉head来压缩模型是可行的,虽然head在底层是并行计算的,在速度上可能提升不大,作者也提供了测试结果:

在剪掉50%的BERT head情况下,是有一定比例的速度提升的。不过作者同时也证明了剪掉50%的head会有效果下降:

大概降了1-2个点的感觉。所以比起我们既想压缩模型,又想提升效果的目标还有一些距离。

方案二:训练时加入Dropout,预测时剪掉Layer

文章[4]正在投稿ICLR2020,来自Facebook AI Research,提出了类似Dropout的LayerDrop,在训练时随机mask掉一组权重,让模型适应这种“缺胳膊少腿”的情况,变得更加robust,然后在预测时直接去掉部分结构进行预测。评估的任务也非常多样,有NMT、摘要、语言模型和NLU。

上图左边是剪去不同结构的效果,可见只剪layer效果就很好了。右图是剪枝的策略,效果简单又好的是Every other(每隔一个)。因此在之后的实验中都是剪去了every other layer,具体的方法参见下式,比如剪枝概率为0.5,那就是剪掉除2后余数为0的层(0, 2, 4, ..., 10)。

d ≡ 0(mod\lfloor\frac{1}{p}\rfloor) \\

在*预训练阶段应用LayerDrop,剪枝之后再精调*的效果要好于蒸馏和同等量级的模型:

遗憾的是,好像并没在文中看到剪枝后模型和原BERT的效果比较。不过去BERT原论文里查了下,BERT base在上述任务的表现分别是:84.4/86.7/88.4/92.7,比剪枝成6层的模型好1-2个点,但速度却慢一半左右。

另外,只将LayerDrop应用在预训练+精调过程中还有提升效果的奇效,而且能提升一定的训练速度

看到LayerDrop这篇文章时觉得已经找到了心中最佳,后续有时间会撸一下代码,其他看了下但没有细读的论文有:

  • [5] 按权重大小进行weight pruning
  • [6] 类似方案二,在SQUAD2.0上进行实验
  • [7] 提出了一种RPP的weight pruning方法

最后我想说,如果美团的同学看到了这篇文章,要不要和大家分享一下呀~

参考:

  1. 美团BERT的探索和实践: mp.weixin.qq.com/s/433H
  2. BERT 瘦身之路:Distillation,Quantization,Pruning: zhuanlan.zhihu.com/p/86
  3. Are Sixteen Heads Really Better than One?: arxiv.org/abs/1905.1065
  4. Reducing Transformer Depth on Demand with Structured Dropout: arxiv.org/abs/1909.1155
  5. Compressing BERT: Studying the Effects of Weight Pruning on Transfer Learning: openreview.net/forum?
  6. Pruning a BERT-based Question Answering Model: arxiv.org/abs/1910.0636
  7. Reweighted Proximal Pruning for Large-Scale Language Representation: arxiv.org/abs/1909.1248

编辑于 09-09

文章被以下专栏收录