都9102年了,别再用Adam + L2 regularization了

都9102年了,别再用Adam + L2 regularization了

前言

L2 regularization 和 Weight decay 只在SGD优化的情况下是等价的。

1.weight decay

Weight decay是在每次更新的梯度基础上减去一个梯度( \boldsymbol{\theta} 为模型参数向量, \nabla f_{t}\left(\boldsymbol{\theta}_{t}\right)t 时刻loss函数的梯度, \alpha 为学习率):

\boldsymbol{\theta}_{t+1}=(1-\lambda) \boldsymbol{\theta}_{t}-\alpha \nabla f_{t}\left(\boldsymbol{\theta}_{t}\right)\\

2.L2 regularization

L2 regularization是给参数加上一个L2惩罚( f_{t}(\boldsymbol{\theta}) 为loss函数):

f_{t}^{r e g}(\boldsymbol{\theta})=f_{t}(\boldsymbol{\theta})+\frac{\lambda^{\prime}}{2}\|\boldsymbol{\theta}\|_{2}^{2}\\

(当 \lambda^{\prime}=\frac{\lambda}{\alpha} ​时,与weight decay等价,仅在使用标准SGD优化时成立)


Adam+L2 regularization

Adam自动调整学习率,大幅提高了训练速度,也很少需要调整学习率,但是有相当多的资料报告Adam优化的最终精度略低于SGD。问题出在哪呢,其实Adam本身没有问题,问题在于目前大多数DL框架的L2 regularization实现用的是weight decay的方式,而weight decay在与Adam共同使用的时候有互相耦合。

adam+L2 regularization(红色); adamw(绿色)

红色是传统的Adam+L2 regularization的方式,梯度 g_t 的移动平均 m_t 与梯度平方的移动平均 v_t 都加入了 \lambda \boldsymbol{\theta_{t-1}}

line 9的 \hat{\boldsymbol{m}}_{t} 是在对于移动平均的初始时刻做修正,当t足够大时, \hat{\boldsymbol{m}}_{t}=\boldsymbol{m}_{t} 。初始时刻 t=1 时,假设 \beta_1=0.9 ,初始化\boldsymbol {m}_{0}=0 , \boldsymbol {m}_1=0.9 \cdot 0 + 0.1 \cdot \boldsymbol {g_1}=0.1 \boldsymbol g_1 ,这显然不合理,但是除以 1-\beta_1^t=1-0.9=0.1\hat{\boldsymbol{m}}_{t}=\boldsymbol{g}_{t} 。line 10同理,因此后面都假设t足够大,\hat{\boldsymbol{m}}_{t}=\boldsymbol{m}_{t}

如果把line 6, line 7, line 8都带入line 12,并假设 \eta_t=1 ( \alpha 为学习率):

\boldsymbol \theta_{t} \leftarrow \boldsymbol \theta_{t-1}-\alpha \frac{\beta_{1} \boldsymbol m_{t-1}+\left(1-\beta_{1}\right)\left(\nabla \boldsymbol f_{t}+\lambda \boldsymbol \theta_{t-1})\right.}{\sqrt{\hat{\boldsymbol {v}}_{t}}+\epsilon}\\

分子右上角的 \lambda \boldsymbol {\theta}_{t-1} 向量各个元素被分母的 \sqrt{\hat{\boldsymbol {v}}_{t}} 项调整了。梯度快速变化的方向上,\sqrt{\hat{\boldsymbol {v}}_{t}}有更大的值,而使得调整后的\frac {\lambda \boldsymbol {\theta}_{t-1}}{\sqrt{\hat{\boldsymbol {v}}_{t}}}更小,在这个方向上 \boldsymbol \theta 被正则化地更少。这显然是不合理的,L2 regularization与weight decay都应该是各向同性的,因此论文作者提出绿色方式来接入weight decay。也即不让\lambda \boldsymbol {\theta}_{t-1}项被 \sqrt{\hat{\boldsymbol {v}}_{t}} 调整。完成梯度下降与weight decay的解耦。

目前bert训练采用的优化方法就是adamw,对除了layernorm,bias项之外的模型参数做weight decay。

大部分的模型都会有L2 regularization约束项,因此很有可能出现adam的最终效果没有sgd的好。如果在tf里面要对不同区域的tensor做不同的L2 regularization调整的化可以参考zhuanlan.zhihu.com/p/40 先做adam后在手工更新L2 regularization梯度的方法。


参考资料:

arxiv.org/pdf/1711.0510

towardsdatascience.com/

zhuanlan.zhihu.com/p/40

编辑于 2019-05-13

文章被以下专栏收录