Group Normalization and Its MXNet / Gluon Implementations

Preface

Before introducing Group Normalization (GN), it is worth reviewing Batch Normalization (BN). GN is, after all, an improvement on BN, and understanding BN helps a great deal in understanding GN.

Batch Normalization

Batch Normalization was proposed by Google in 2015 and is described very clearly in the ICML paper: at each SGD step, the activations are normalized over the mini-batch so that each dimension of the output has zero mean and unit variance. The final "scale and shift" operation is deliberately added so that the learned transformation can, if necessary, recover the original input, preserving the capacity of the network.
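
As a concrete illustration, here is a minimal NumPy sketch (mine, not from the paper) of the training-time computation for an (N, C, H, W) activation:

import numpy as np

def batch_norm_train(x, gamma, beta, eps=1e-5):
    # x: (N, C, H, W); gamma, beta: (1, C, 1, 1)
    # statistics are computed over the batch and spatial axes for each channel
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)   # zero mean, unit variance per channel
    return gamma * x_hat + beta               # learnable scale and shift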

Since its introduction, BN has become an indispensable component of CNNs. Batch Normalization has many desirable properties, for example:

(1) BN keeps the distribution of each layer's inputs relatively stable, which speeds up learning;

(2) BN makes the model less sensitive to the network's parameters, simplifying tuning and making training more stable;

(3) BN allows the network to use saturating activation functions (e.g. sigmoid, tanh) by alleviating the vanishing-gradient problem;

(4) BN has a certain regularization effect.

Drawbacks of Batch Normalization

Batch Normalization (BN) normalizes the input x of each layer independently per mini-batch, but the normalization parameters are the first- and second-order statistics of that mini-batch. Because BN normalizes along the batch dimension, it is constrained by the batch size: when the batch size is small, BN's statistical estimates become inaccurate and the model error increases noticeably. A batch size of about 32 per GPU is generally considered ideal. However, for tasks such as object detection, semantic segmentation, and video understanding, the input images are large, and GPU memory limits prevent a large batch size. In the classic Faster R-CNN and Mask R-CNN networks, for example, the high image resolution means the batch size can only be 1 or 2.

On the other hand, Batch Normalization normalizes along the batch dimension, which is not fixed: it typically differs between training and testing. During training, the mean and variance are pre-computed on the training set via a moving average; at test time these values are not recomputed but simply looked up. When the training and test data distributions differ, however, the pre-computed mean and variance do not truly reflect the test set, which introduces an inconsistency between the training, validation, and test phases.
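
A rough sketch of what this means at inference time (variable names are illustrative; the exact moving-average bookkeeping varies by framework):

import numpy as np

def batch_norm_inference(x, gamma, beta, running_mean, running_var, eps=1e-5):
    # running_mean / running_var were accumulated during training, e.g.
    #   running_mean = momentum * running_mean + (1 - momentum) * batch_mean
    # If the test distribution differs from the training distribution,
    # these pre-computed statistics no longer describe the data being normalized.
    x_hat = (x - running_mean) / np.sqrt(running_var + eps)
    return gamma * x_hat + beta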

Group Normalization

Group Normalization (GN) was proposed by Kaiming He's team in March 2018. GN addresses BN's weakness of performing poorly with small mini-batches. Normalizing along the batch dimension has an inherent problem: as the batch shrinks, the batch statistics become inaccurate and BN's error grows rapidly. When training large networks or transferring features to computer vision tasks such as detection, segmentation, and video, memory consumption forces the use of small batches, and a small batch size causes Batch Normalization to break down.

Group Normalization (GN) is proposed as a replacement for BN: it first divides the channels into groups and then computes the mean and variance within each group for normalization. GN's computation is independent of the batch size, and its accuracy stays stable across different batch sizes. GN is also easy to fine-tune from pre-trained models. GN and BN are compared in the figure from the paper.

Experiments on Replacing BN with GN

The TensorFlow pseudocode for GN given by the authors in the paper is very easy to follow:

def GroupNorm(x, gamma, beta, G, eps=1e-5):
    # x: input features with shape [N,C,H,W]
    # gamma, beta: scale and offset, with shape [1,C,1,1]
    # G: number of groups for GN
    N, C, H, W = x.shape
    x = tf.reshape(x, [N, G, C // G, H, W])
    mean, var = tf.nn.moments(x, [2, 3, 4], keep_dims=True)
    x = (x - mean) / tf.sqrt(var + eps)
    x = tf.reshape(x, [N, C, H, W])
    return x * gamma + beta

Since MXNet is my main deep learning framework at work, I searched online but could not find an open-source MXNet implementation, so I decided to implement this method in MXNet myself. Below I first summarize the implementations of Group Normalization in other frameworks found online, and then give implementations for MXNet's Module interface and its Gluon interface.

TensorFlow Implementation of GN

def norm(x, norm_type, is_train, G=32, esp=1e-5):
    with tf.variable_scope('{}_norm'.format(norm_type)):
        if norm_type == 'none':
            output = x
        elif norm_type == 'batch':
            output = tf.contrib.layers.batch_norm(
                x, center=True, scale=True, decay=0.999,
                is_training=is_train, updates_collections=None)
        elif norm_type == 'group':
            # normalize
            # transpose: [bs, h, w, c] to [bs, c, h, w] following the paper
            x = tf.transpose(x, [0, 3, 1, 2])
            N, C, H, W = x.get_shape().as_list()
            G = min(G, C)
            x = tf.reshape(x, [-1, G, C // G, H, W])
            mean, var = tf.nn.moments(x, [2, 3, 4], keep_dims=True)
            x = (x - mean) / tf.sqrt(var + esp)
            # per channel gamma and beta
            gamma = tf.Variable(tf.constant(1.0, shape=[C]), dtype=tf.float32, name='gamma')
            beta = tf.Variable(tf.constant(0.0, shape=[C]), dtype=tf.float32, name='beta')
            gamma = tf.reshape(gamma, [1, C, 1, 1])
            beta = tf.reshape(beta, [1, C, 1, 1])

            output = tf.reshape(x, [-1, C, H, W]) * gamma + beta
            # transpose: [bs, c, h, w] back to [bs, h, w, c]
            output = tf.transpose(output, [0, 2, 3, 1])
        else:
            raise NotImplementedError
    return output
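
For an NHWC feature map, the function above would be used roughly as follows (an illustrative sketch only, assuming the TF 1.x API used by the code above):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 56, 56, 64])    # NHWC input
y = norm(x, norm_type='group', is_train=True, G=32)   # same shape as x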

PyTorch Implementation of GN

import numpy as np
import torch
import torch.nn as nn


class GroupNorm(nn.Module):
    def __init__(self, num_features, num_groups=32, eps=1e-5):
        super(GroupNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(1,num_features,1,1))
        self.bias = nn.Parameter(torch.zeros(1,num_features,1,1))
        self.num_groups = num_groups
        self.eps = eps

    def forward(self, x):
        N,C,H,W = x.size()
        G = self.num_groups
        assert C % G == 0

        x = x.view(N,G,-1)
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, keepdim=True)

        x = (x-mean) / (var+self.eps).sqrt()
        x = x.view(N,C,H,W)
        return x * self.weight + self.bias
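
A quick smoke test of this module (illustrative usage, not part of the original code):

x = torch.randn(2, 64, 32, 32)                # small batch: N=2, C=64, H=32, W=32
gn = GroupNorm(num_features=64, num_groups=32)
y = gn(x)                                     # (2, 64, 32, 32), normalized within each group
print(y.shape)                                # torch.Size([2, 64, 32, 32])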

MXNet Implementation of Group Normalization

First, I implemented the GroupNorm computation as a custom op in MXNet. For how to write custom ops in MXNet, see the official documentation: How to Create New Operators (Layers). The main difficulty with a custom op is that you must implement both the Forward and the Backward functions yourself.
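
For reference, these are the per-group forward and backward formulas that the code below implements (my notation; $m = (C/G)\cdot H\cdot W$ is the number of elements in one group, and the sums run over that group):

$$\mu = \frac{1}{m}\sum_i x_i,\qquad \sigma^2 = \frac{1}{m}\sum_i (x_i-\mu)^2,\qquad \hat{x}_i = \frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}},\qquad y_i = \gamma\,\hat{x}_i+\beta$$

For the backward pass, with $d\hat{x}_i = \gamma\, dy_i$:

$$d\sigma^2 = -\frac{1}{2}\sum_i d\hat{x}_i\,(x_i-\mu)\,(\sigma^2+\epsilon)^{-3/2},\qquad d\mu = -\sum_i \frac{d\hat{x}_i}{\sqrt{\sigma^2+\epsilon}} - \frac{2\,d\sigma^2}{m}\sum_i (x_i-\mu)$$

$$dx_i = \frac{d\hat{x}_i}{\sqrt{\sigma^2+\epsilon}} + \frac{d\mu}{m} + \frac{2\,d\sigma^2}{m}\,(x_i-\mu)$$

These correspond to dvar, dmean (dmean1 + dmean2_var), and dx_group (dx_group1 + dx_group2_mean + dx_group3_var) in the Backward code.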

import numpy as np
import mxnet as mx

class GroupNorm(mx.operator.CustomOp):
    """
    If the batch size is small, it's better to use GroupNorm instead of BatchNorm.
    GroupNorm achieves good results even at small batch sizes.
    Reference:
      https://arxiv.org/pdf/1803.08494.pdf
    """
    def __init__(self, gamma, beta, num_groups=32, eps=1e-5,**kwargs):
        super(GroupNorm, self).__init__(**kwargs)
        self.gamma = gamma
        self.beta = beta
        self.G = num_groups
        self.eps = eps
        self.mean = None
        self.var = None
        self.x_norm = None


    def forward(self, is_train, req, in_data, out_data, aux):
        """
        Computes the forward pass for spatial group normalization.
        In contrast to layer normalization, group normalization splits each entry 
        in the data into G contiguous pieces, which it then normalizes independently.
        Per feature shifting and scaling are then applied to the data, in a manner identical 
        to that of batch normalization and layer normalization.


        Inputs:
        - x: Input data of shape (N, C, H, W)
        - gamma: Scale parameter, of shape (C,)
        - beta: Shift parameter, of shape (C,)
        - G: Integer mumber of groups to split into, should be a divisor of C
        - gn_param: Dictionary with the following keys:
        - eps: Constant for numeric stability   

        Returns a tuple of:
        - out: Output data, of shape (N, C, H, W)                                        
        """

        x = in_data[0]
        N,C,H,W = x.shape
        # group the channel by G
        x_group = x.reshape((N, self.G, -1, H, W))
        self.mean = mx.nd.mean(x_group, axis=(2, 3, 4), keepdims=True) 
        self.var = mx.nd.mean((x_group - self.mean) ** 2, axis=(2, 3, 4), keepdims=True)
        # Normalization
        x_groupnorm = (x_group - self.mean) / mx.nd.sqrt(self.var + self.eps)
        # reshape to (N,C,H,W)
        self.x_norm = x_groupnorm.reshape((N,C,H,W))
        # output the group normalization result
        x_gn = self.x_norm * self.gamma + self.beta
        self.assign(out_data[0], req[0], x_gn)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        """
        Implement the backward pass for spatial group normalization.      
        This will be extremely similar to the layer norm implementation. 

        Inputs:
        - dout: Upstream derivatives, of shape (N, C, H, W)

        Returns a tuple of:
        - dx: Gradient with respect to inputs, of shape (N, C, H, W)
        - dgamma: Gradient with respect to scale parameter, of shape (C,)
        - dbeta: Gradient with respect to shift parameter, of shape (C,)

        """       
        dx, dgamma, dbeta = None, None, None 
        x = in_data[0]
        dout = out_grad[0]
        N,C,H,W = dout.shape
        # dbeta, dgamma 
        dbeta = mx.nd.sum(dout, axis=(0,2,3), keepdims=True) 
        dgamma = mx.nd.sum(dout*self.x_norm, axis=(0,2,3), keepdims=True)
        # get dx_group,(N, G, C // G, H, W)
        # dx_groupnorm
        dx_norm = dout * self.gamma 
        dx_groupnorm = dx_norm.reshape((N, self.G, C // self.G, H, W)) 
        # dvar
        x_group = x.reshape((N, self.G, C //self.G, H, W))
        dvar = mx.nd.sum(dx_groupnorm * -1.0 / 2 * (x_group - self.mean) / (self.var + self.eps) ** (3.0 / 2), axis=(2,3,4), keepdims=True)
        # dmean
        N_GROUP = C//self.G*H*W
        dmean1 = mx.nd.sum(dx_groupnorm * -1.0 / mx.nd.sqrt(self.var + self.eps), axis=(2,3,4), keepdims=True)
        dmean2_var = dvar * -2.0 / N_GROUP * mx.nd.sum(x_group - self.mean, axis=(2,3,4), keepdims=True)
        dmean = dmean1 + dmean2_var
        # dx_group
        dx_group1 = dx_groupnorm * 1.0 / mx.nd.sqrt(self.var + self.eps)
        dx_group2_mean = dmean * 1.0 / N_GROUP
        dx_group3_var = dvar * 2.0 / N_GROUP * (x_group - self.mean)
        dx_group = dx_group1 + dx_group2_mean + dx_group3_var
        # reshape 
        dx = dx_group.reshape((N, C, H, W)) 
        self.assign(in_grad[0], req[0], dx)

@mx.operator.register("groupnorm")  # register with name "groupnorm"
class GroupNormProp(mx.operator.CustomOpProp):
    def __init__(self, gamma, beta, num_groups=32, eps=1e-5,**kwargs):
        super(GroupNormProp, self).__init__(need_top_grad=True)
        """
        All arguments are in string format so you need
        to convert them back to the type you want.
        """
        self.gamma = float(gamma)
        self.beta = float(beta)
        self.G = int(num_groups)
        self.eps = float(eps)

    def list_arguments(self):
        return ['data']

    def list_outputs(self):
        #this can be omitted if you only have 1 output.
        return ['output']

    def infer_shape(self, in_shapes):
        data_shape = in_shapes
        output_shape = data_shape
        #return 3 lists representing inputs shapes, outputs shapes, and aux data shapes.
        return data_shape, output_shape, []

    def infer_type(self, in_type):
        dtype = in_type
        return (dtype), (dtype), ()

    def create_operator(self, ctx, shapes, dtypes):
        # create and return the CustomOp class.
        return GroupNorm(self.gamma, self.beta, self.G, self.eps)
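
With the op registered under the name "groupnorm", it can be dropped into a Symbol graph via mx.sym.Custom. A minimal sketch (keyword arguments follow GroupNormProp above; note that here gamma and beta are scalar hyperparameters rather than learned per-channel parameters):

data = mx.sym.Variable('data')
gn = mx.sym.Custom(data, gamma=1.0, beta=0.0, num_groups=32, eps=1e-5,
                   op_type='groupnorm', name='gn0')

# quick forward check on a small input
x = mx.nd.random.normal(shape=(2, 64, 8, 8))
exe = gn.bind(mx.cpu(), {'data': x})
print(exe.forward(is_train=False)[0].shape)   # (2, 64, 8, 8)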

Gluon Implementation of Group Normalization

The implementation above uses MXNet's native API and can be called directly from MXNet Symbol and NDArray code. For MXNet's Gluon module, the corresponding implementation is given below.

from mxnet.gluon import nn

class GroupNorm(nn.HybridBlock):
    """
    If the batch size is small, it's better to use GroupNorm instead of BatchNorm.
    GroupNorm achieves good results even at small batch sizes.
    Reference:
      https://arxiv.org/pdf/1803.08494.pdf
    """
    def __init__(self, num_channels, num_groups=32, eps=1e-5,
                 multi_precision=False, **kwargs):
        super(GroupNorm, self).__init__(**kwargs)

        with self.name_scope():
            self.weight = self.params.get('weight', grad_req='write',
                                          shape=(1, num_channels, 1, 1))
            self.bias = self.params.get('bias', grad_req='write',
                                        shape=(1, num_channels, 1, 1))
        self.C = num_channels
        self.G = num_groups
        self.eps = eps
        self.multi_precision = multi_precision

        assert self.C % self.G == 0

    def hybrid_forward(self, F, x, weight, bias):

        x_new = F.reshape(x, (0, self.G, -1))                                # (N,C,H,W) -> (N,G,H*W*C//G)

        if self.multi_precision:
            mean = F.mean(F.cast(x_new, "float32"),
                          axis=-1, keepdims=True)                            # (N,G,H*W*C//G) -> (N,G,1)
            mean = F.cast(mean, "float16")
        else:
            mean = F.mean(x_new, axis=-1, keepdims=True)

        centered_x_new = F.broadcast_minus(x_new, mean)                      # (N,G,H*W*C//G)

        if self.multi_precision:
            var = F.mean(F.cast(F.square(centered_x_new),"float32"),
                         axis=-1, keepdims=True)                             # (N,G,H*W*C//G) -> (N,G,1)
            var = F.cast(var, "float16")
        else:
            var = F.mean(F.square(centered_x_new), axis=-1, keepdims=True)

        x_new = F.broadcast_div(centered_x_new, F.sqrt(var + self.eps)       # (N,G,H*W*C//G) -> (N,C,H,W)
                                ).reshape_like(x)
        x_new = F.broadcast_add(F.broadcast_mul(x_new, weight),bias)
        return x_new
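
A minimal usage sketch for the Gluon block (illustrative; in practice you would probably initialize the weight to 1 and the bias to 0, e.g. by passing an init when the parameters are created):

import mxnet as mx
from mxnet import nd

net = GroupNorm(num_channels=64, num_groups=32)
net.initialize()                              # default initializer
x = nd.random.normal(shape=(2, 64, 32, 32))
y = net(x)                                    # (2, 64, 32, 32), normalized within each group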

Implementing GroupNorm via a Custom Computation Graph

If you prefer a simpler and more transparent approach, below I also implement the GroupNorm computation as a custom layer built from MXNet's native Symbol API. This approach loses no speed, does not require writing a Backward function, and the code is simple and efficient.

def GroupNorm(data, in_channel, name, num_groups=32, eps=1e-5):
    """
    If the batch size is small, it's better to use GroupNorm instead of BatchNorm.
    GroupNorm achieves good results even at small batch sizes.
    Reference:
      https://arxiv.org/pdf/1803.08494.pdf
    """
    # data: input features with shape [N,C,H,W]
    # gamma, beta: scale and offset, with shape [1,C,1,1]
    # G: number of groups for GN
    C = in_channel
    G = num_groups
    G = min(G, C)
    x_group = mx.sym.reshape(data=data, shape=(0, G, C // G, 0, -1))  # (N,C,H,W) -> (N, G, C//G, W, H); 0 copies the input dim
    mean = mx.sym.mean(x_group, axis= (2, 3, 4), keepdims = True) 
    differ = mx.sym.broadcast_minus(lhs = x_group, rhs = mean)
    var = mx.sym.mean(mx.sym.square(differ), axis = (2, 3, 4), keepdims =True)
    x_groupnorm = mx.sym.broadcast_div(lhs = differ, rhs = mx.sym.sqrt(var + eps))
    x_out = mx.sym.reshape_like(x_groupnorm, data)
    gamma = mx.sym.Variable(name = name + '_gamma',shape = (1,C,1,1), dtype='float32')
    beta = mx.sym.Variable(name = name + '_beta', shape=(1,C,1,1), dtype='float32')
    gn_x = mx.sym.broadcast_mul(lhs = x_out, rhs = gamma)
    gn_x = mx.sym.broadcast_plus(lhs = gn_x, rhs = beta)
    return gn_x
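
A usage sketch for this helper when building a network graph (layer names and shapes are hypothetical):

data = mx.sym.Variable('data')
conv1 = mx.sym.Convolution(data=data, num_filter=64, kernel=(3, 3), pad=(1, 1), name='conv1')
gn1 = GroupNorm(conv1, in_channel=64, name='gn1', num_groups=32)
act1 = mx.sym.Activation(data=gn1, act_type='relu', name='relu1')

Because gamma and beta are declared as mx.sym.Variable inside the helper, they become ordinary learnable arguments of the resulting symbol, and gradients flow through the native operators automatically, which is why no hand-written Backward is needed.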

All of the code above is available on my GitHub; feel free to use it and send feedback.

GitHub: jianzhnie/GroupNorm-MXNet
