深度学习推理时融合BN,轻松获得约5%的提速

深度学习推理时融合BN,轻松获得约5%的提速

批归一化(Batch Normalization)因其可以加速神经网络训练、使网络训练更稳定,而且还有一定的正则化效果,所以得到了非常广泛的应用。但是,在推理阶段,BN层一般是可以完全融合到前面的卷积层的,而且丝毫不影响性能。
本文首发:AIZOO

Batch Normalization是谷歌研究员于2015年提出的一种归一化方法,其思想非常简单,一句话概括就是,对一个神经元(或者一个卷积核)的输出减去统计得到的均值除以标准差然后乘以一个可学习的系数,再加上一个偏置,这个过程就完成了。

下面我们简单介绍一下BN训练时怎么做,推理的时候为什么可以融合,以及怎么样融合。

一. BN训练时如何做

训练过程中BN层的运算,用公式表达也很简单,对于一个Batch内的第 i 个样本,假设某个神经元的输出是 x_{i} , 则经过 BN 层后的输出 也就是:

 y_{i}=\gamma \frac{x_{i}-\mu}{\sqrt{\sigma ^{2}+\varepsilon }}+\beta

其中 \mu 为一个Batch内 的均值, \sigma^2 为一个Batch内的 的标准差, \epsilon 为一个非常小的常数,例如 0.001, 主要是为了避免除零错误,均值和方差的计算方法分别为:

                         \mu=\frac{1}{m}\sum_{i=1}^{m}x_{i}

\sigma^2=\frac{1}{m}\sum_{i=1}^{m}(x_{i}-\mu)^2

\gamma\beta 是一个可学习的参数,在训练过程中,和其他卷积核的参数一样,通过梯度下降来学习。

在训练过程中,为保持稳定,一般使用滑动平均法更新均值和方差,滑动平均就是在更新当前值的时候,以一定比例保存之前的数值,以均值 \mu 为例,以一定比例 \theta (例如这里0.99)保存之前的均值,当前只更新 (1-  \theta ) 倍(也就是0.001倍)的本Batch的均值,计算方法如下:

\mu_{i}=\theta\mu_{i-1} + (1-\theta)\mu_{i}

标准差的滑动平均计算方法也一样。

二. BN推理时怎么做

大家要注意的是,在训练的时候,均值 \mu 、方差 \sigma^2\gamma\beta 是一直在更新的,但是,在推理的时候,以上四个值都是固定了的,也就是推理的时候,均值和方差来自训练样本的数据分布。

因此,在推理的时候,上面BN的计算公式可以变形为:

y_{i}=\gamma \frac{x_{i}-\mu}{\sqrt{\sigma ^{2}+\varepsilon }))}+\beta =  \frac{\gamma}{\sqrt{\sigma ^{2}+\varepsilon }}x_{i}+(\beta - \frac{\gamma\mu}{\sqrt{\sigma ^{2}+\varepsilon }))})

大家应该可以发现,在均值 \mu 、方差 \sigma^2\gamma\beta 都是固定值的时候,上面公式可以改写为

y_{i} = ax_{i}+b

其中, a=\frac{\gamma}{\sqrt{\sigma ^{2}+\varepsilon }} , b=\beta - \frac{\gamma\mu}{\sqrt{\sigma ^{2}+\varepsilon }} , 在推理的时候,都是固定不变的常数。我们以一个三个神经元输入的全连接网络为例,如下图:

三个输入的全连接网络图

则全连接输出:

x_{i}=w_{1}\cdot z_{1}+w_{2}\cdot z_{2}+w_{3}\cdot z_{3}+c

x_{i}=w_{1}\cdot z_{1}+w_{2}\cdot z_{2}+w_{3}\cdot z_{3}+c

其中 c 为偏置(这里为避免与上面的 b 冲突,所以用 c 表示),那么全连接 + BN 一起,则是

y_{i} = ax_{i}+b=a(w_{1}\cdot z_{1}+w_{2}\cdot z_{2}+w_{3}\cdot z_{3}+c)+b

也就是

y_{i} = aw_{1}\cdot z_{1}+aw_{2}\cdot z_{2}+aw_{3}\cdot z_{3}+(ac+b)

到这里大家应该清楚了,因为推理时,BN是一个线性的操作,也就是一个缩放+一个偏移,我们完全可以把这个线性操作叠加到前面的全连接层或者卷积层,只需要把全连接或者卷积层的权重乘以一个系数 a ,偏置从 c 变为 ac+b 就可以了了。完整的过程如下图:

Conv+BN融合到BN

三. 在框架中如何融合

在训练时候,在卷积层后面直接加BN层,训练完成后,我们只需要将网络中BN层去掉,读取原来的卷积层权重和偏置,以及BN层的四个参数(均值 \mu 、方差 \sigma^2\gamma\beta ),然后按照上面的计算方法替换卷积核的权重,更新偏置就可以了。

下面是来自博文[1]中的一个PyTorch例子,将ResNet18中一个卷积+BN层融合后,融合前后输出的差值为-6.10425390790148e-11,也就是误差在百亿分之一,基本就是0了。

    import torch
    import torchvision
    
    def fuse(conv, bn):
    
        fused = torch.nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            bias=True
        )
    
        # setting weights
        w_conv = conv.weight.clone().view(conv.out_channels, -1)
        w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps+bn.running_var)))
        fused.weight.copy_( torch.mm(w_bn, w_conv).view(fused.weight.size()) )
        
        # setting bias
        if conv.bias is not None:
            b_conv = conv.bias.mul(bn.weight).div(
                                          torch.sqrt(bn.running_var + bn.eps)
                                          )
        else:
            b_conv = torch.zeros( conv.weight.size(0) )
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
                              torch.sqrt(bn.running_var + bn.eps)
                            )
        fused.bias.copy_( b_conv + b_bn )
    
        return fused
    
    # Testing
    # we need to turn off gradient calculation because we didn't write it
    torch.set_grad_enabled(False)
    x = torch.randn(16, 3, 256, 256)
    resnet18 = torchvision.models.resnet18(pretrained=True)
    # removing all learning variables, etc
    resnet18.eval()
    model = torch.nn.Sequential(
        resnet18.conv1,
        resnet18.bn1
    )
    f1 = model.forward(x)
    fused = fuse(model[0], model[1])
    f2 = fused.forward(x)
    d = (f1 - f2).mean().item()
    print("error:",d)

因为这么一个线性操作的转换,如果有误差,那才真是见鬼了呢。

关于其他框架,如Keras、Caffe、TensorFlow的操作,与PyTorch基本一个原理,大家可以自己试验一下。

笔者在测试时候,发现融合掉BN后,会有大概5%的提速,而且还可以减小显存消耗,又丝毫不影响误差,何乐而不为呢。

但是,融合BN仅限于Conv+BN或者是BN+Conv结构,中间不能加非线性层,例如Conv+ReLu+BN那就不行了。当然,一般结构都是Conv+BN+ReLu结构。

本文完,喜欢的朋友欢迎关注、点赞、转发,三联支持哦。


精彩推荐

都2020年了,在校学生还值得继续转行搞AI吗

应届算法岗,选择巨头还是AI明星创业公司

2020年代,中国AI创业公司将走向何方

AIZOO开源人脸口罩检测数据+模型+代码+在线网页体验,通通都开源了

人脸口罩检测现开源PyTorch、TensorFlow、MXNet等全部五大主流深度学习框架模型和代码

编辑于 03-30

文章被以下专栏收录

    分享人工智能的技术和资讯,另外,我也做了一个网站AIZOO.com,欢迎大家去在线体验好玩的人工智能算法。