<深度学习优化策略-1>Batch Normalization(BN)

今天给大家带来深度学习的优化策略篇的第一篇Batch Normalization(BN)。BN可以看做对输入样本的一种约束,最大作用是加速收敛,减少模型对dropout,careful weight initialnization依赖,可以adopt higher learning rate的优势,收敛速度可以提高10倍以上。

问题提出:

深度网络参数训练时内部存在协方差偏移(Internal Covariate Shift)现象:深度网络内部数据分布在训练过程中发生变化的现象。

为什么会带来不好影响:训练深度网络时,神经网络隐层参数更新会导致网络输出层输出数据的分布发生变化,而且随着层数的增加,根据链式规则,这种偏移现象会逐渐被放大。这对于网络参数学习来说是个问题:因为神经网络本质学习的就是数据分布(representation learning),如果数据分布变化了,神经网络又不得不学习新的分布。为保证网络参数训练的稳定性和收敛性,往往需要选择比较小的学习速率(learning rate),同时参数初始化的好坏也明显影响训练出的模型精度,特别是在训练具有饱和非线性(死区特性)的网络,比如即采用S或双S激活函数网络,比如LSTM,GRU。

解决办法:引入Batch Normalization,作为深度网络模型的一个层,每次先对input数据进行归一化,再送入神经网络输入层。

Batch normalization实现:

1、使网络某一层的输入样本做白化处理(最后发现等价于零均值化(Normalization)处理,拥有零均值,方差为1),输入样本之间不相关。通过零均值化每一层的输入,使每一层拥有服从相同分布的输入样本,因此克服内部协方差偏移的影响。

E(X)是输入样本X的期望,Var是输入样本X的方差。注意,对于一个d维的输入样本X=(x1,x2,....xd),要对某一层所有的维度一起进行零均值化处理,计算量大,且部分地方不可导,因此,这里的是针对每个维度k分别处理。

2、数据简化:输入样本X规模可能上亿,直接计算几乎是不可能的。因此,选择每个batch来进行Normalization,得出Batch Normalization(BN)的处理方式:

从上图可以看出,因为简单的对数据进行normalize处理会降低,把数据限制在[0,1]的范围内,对于型激活函数来说,只使用了其线性部分,会限制模型的表达能力。所以,还需要对normalize之后的进行变换:

r和B做为参数,可以通过网络训练学到。

3、BN的参数求导:输入样本进行变换之后,要进行梯度反向传播,就必须计算BN各个参数的梯度,梯度计算公式如下:

可见,通过BN对输入样本进行变换是可导的,不会对梯度反向传播造成影响。

4、为了加快训练,我们在模型训练阶段使用BN,但是在推理阶段并不一定要使用。训练一个BN网络完整的流程:

Batch Normalization的优势总结:

ad1:减少梯度对参数大小或初始值的依赖,使网络在使用较大学习速率训练网络参数时,也不会出现参数发散的情况。过高的学习速率会导致梯度爆炸或梯度消失,或者陷入局部极小值,BN可以帮助处理这个问题:通过把梯度映射到一个值大但次优的变化位置来阻止梯度过小变化:比如,阻止训练陷入饱和死区。

BN使参数在训练时更加灵活。过大的学习速率会增加参数的规模,导致反向传播时发生梯度爆炸。BN中,反向传播时不会受梯度规模影响。

可见,权重越大,梯度越小,BN时参数变化更稳定。可以推测,当输入样本X服从高斯分布且输入元素之间相互独立时,BN可以使层参数的Jacobian拥有接近于1的单值,能够完整保留梯度的值在反向传播时不衰减。

ad2:正则化模型,减少对dropout的依赖。

BN正则化模型:使用BN训练时,一个样本只与minibatch中其他样本有相互关系;对于同一个训练样本,网络的输出会发生变化。这些效果有助于提升网络泛化能力,像dropout一样防止网络过拟合,同时BN的使用,可以减少或者去掉dropout类似的策略。

ad3:网络可以采用具有非线性饱和特性的激活函数(比如:s型激活函数),因为它可以避免网络陷入饱和状态。

参考论文:

Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

发布于 2017-05-03

文章被以下专栏收录

    专注深度学习、NLP相关技术、资讯,追求纯粹的技术,享受学习、分享的快乐。欢迎扫描头像二维码或者微信搜索“深度学习与NLP”公众号添加关注,获得更多深度学习与NLP方面的经典论文、实践经验和最新消息。