[ICCV 2019 Oral] 期望最大化注意力网络 EMANet 详解

[ICCV 2019 Oral] 期望最大化注意力网络 EMANet 详解

转载请私信或邮件(ethanlee@pku.edu.cn)

本文介绍笔者被ICCV 2019接受为Oral的论文 Expectation-Maximization Attention Networks for Semantic Segmentation。作者:李夏钟之声吴建龙杨一博林宙辰刘宏。代码已开源在:github.com/XiaLiPKU/EMA

背景介绍

语义分割是计算机视觉领域的一项基础任务,它的目标是为每个像素预测类别标签。由于类别多样繁杂,且类间表征相似度大,语义分割要求模型具有强大的区分能力。近年来,基于全卷积网络(FCN[1])的一系列研究,在该任务上取得了卓越的成绩。

这些语义分割网络,由骨干网络和语义分割头组成。全卷积网络受制于较小的有效感知域,无法充分捕获长距离信息。为弥补这一缺陷,诸多工作提出提出了高效的多尺度上下文融合模块,例如全局池化层、Deeplab[2]的空洞空间卷积池化金字塔、PSPNet[3]的金字塔池化模块等。

近年来,自注意力机制在自然语言处理领域取得卓越成果。Nonlocal[4]被提出后,在计算机视觉领域也受到了广泛的关注,并被一系列文章证明了在语义分割中的有效性。它使得每个像素可以充分捕获全局信息。然而,自注意力机制需要生成一个巨大的注意力图,其空间复杂度和时间复杂度巨大。其瓶颈在于,每一个像素的注意力图都需要对全图计算。

本文所提出的期望最大化注意力机制(EMA),摒弃了在全图上计算注意力图的流程,转而通过期望最大化(EM)算法迭代出一组紧凑的基,在这组基上运行注意力机制,从而大大降低了复杂度。其中,E步更新注意力图,M步更新这组基。E、M交替执行,收敛之后用来重建特征图。本文把这一机制嵌入网络中,构造出轻量且易实现的EMA Unit。其作为语义分割头,在多个数据集上取得了较高的精度。

期望最大化注意力机制

前提知识

期望最大化算法

期望最大化(EM)算法旨在为隐变量模型寻找最大似然解。对于观测数据 \mathbf{X} = \left\{ \mathbf{x}_1, \mathbf{x}_2, \dots, \mathbf{x}_N \right\} ,每一个数据点 \mathbf{x}_i 都对应隐变量 \mathbf{z}_i 。我们把 \left\{ \mathbf{X}, \mathbf{Z} \right\} 称为完整数据,其似然函数为 \ln p \left( \mathbf{X}, \mathbf{Z} | \bm{\theta} \right)\bm{\theta} 是模型的参数。

E步根据当前参数 \theta^{old} 计算隐变量 \mathbf{Z} 的后验分布,并以之寻找完整数据的似然\mathcal{Q} \left( \bm{\theta}, \bm{\theta}^{old} \right)

\mathcal{Q} \left( \bm{\theta}, \bm{\theta}^{old} \right) = \sum_{\mathbf{z}} p \left(  \mathbf{Z} | \mathbf{X}, \bm{\theta}^{old} \right) \ln p \left( \mathbf{X}, \mathbf{Z} | \bm{\theta} \right) \tag{1}

M步通过最大化似然函数来更新参数得到 \theta^{new}

\bm{\theta}^{\mathrm{new}} = \underset{\bm{\theta}}{\arg\max} \mathcal{Q} \left( \bm{\theta}, \bm{\theta}^{\mathrm{old}} \right)\tag{2}

EM算法被证明会收敛到局部最大值处,且迭代过程完整数据似然值单调递增。

高斯混合模型(GMM)是EM算法的一个范例,它把数据用多个高斯分布拟合。其 \bm{\theta}_k 即为第 k 个高斯分布的参数 \bm{\mu}_k, \Sigma_k ,隐变量 \mathbf{z}_{nk} 为第 k 个高斯分布对第 n 数据点的“责任”。E步更新“责任”,M步更新高斯参数。在实际应用中, \Sigma_k 经常被简化为 \mathbf{I}

非局部网络

非局部网络(Nonlocal)率先将自注意力机制使用在计算机视觉任务中。其核心算子是:

\mathbf{y}_i = \frac {1} {C \left( \mathbf{x} \right)} \sum_{\forall j} f \left( \mathbf{x}_i, \mathbf{x}_j \right) g \left( \mathbf{x}_j \right) \tag{3}

其中 f \left(\cdot, \cdot\right) 表示广义的核函数, \mathcal{C} \left( \mathbf{x}  \right) 是归一化系数。它将第 i 个像素的特征 \mathbf{x}_i 更新为其他所有像素特征经过 g 变换之后的加权平均 \mathbf{y}_i ,权重通过归一化后的核函数计算,表征两个像素之间的相关度。这里 1 < j < N ,所以视为像素特征被一组过完备的基进行了重构。这组基数目巨大,且存在大量信息冗余。

期望最大化注意力机制

期望最大化注意力机制由 A_E, A_M, A_R 三部分组成,前两者分别对应EM算法的E步和M步。

假定输入的特征图为 \mathbf{X} \in R^{N \times C},基初始值为 \bm{\mu} \in R^{K \times C}A_E 估计隐变量 \mathbf{Z} \in R^{N \times K},即每个基对像素的权责。具体地,第 k 个基对第 n 个像素的权责可以计算为:

z_{nk} = \frac { \mathcal{K} \left( \mathbf{x}_n, \bm{\mu}_k \right) } { \sum^K_{j=1} \mathcal{K} \left( \mathbf{x}_n, \bm{\mu}_j \right) } \tag{4}

在这里,内核 \mathcal{K} \left( \mathbf{a}, \mathbf{b} \right) 可以有多种选择。我们选择 \exp \left( \mathbf{a}^{\top} \mathbf{b} \right) 的形式。在实现中,可以用如下的方式实现:

\mathbf{Z} = softmax \left( \lambda \mathbf{X} \left( \bm{\mu}^{\top} \right) \right) \tag{5}

其中, \lambda 作为超参数来控制 \mathbf{Z} 的分布。

A_M 步更新基 \bm{\mu} 。为了保证 \bm{\mu}\mathbf{X} 处在同一表征空间内,此处 \bm{\mu} 被计算作 \mathbf{X} 的加权平均。具体地,第 k 个基被更新为:

\bm{\mu}_k = \frac { \sum^N_{n=1} z_{nk} \mathbf{x}_n } { \sum^N_{n=1} z_{nk} } \tag{6}

值得注意的是,如果 \lambda \to \infty ,则公式(5)中, \left\{ z_{n1}, z_{n2}, \dots , z_{nK} \right\} 会变成一组one-hot编码。在这种情形下,每个像素仅由一个基负责,而基被更新为其所负责的像素的均值,这便是标准的K-means算法。

A_EA_M 交替执行 T 步。此后,近似收敛的 \bm{\mu}\mathbf{Z} 便可以被用来对 \mathbf{X} 进行重估计得 \tilde{\mathbf{X}}

\tilde{\mathbf{X}} = \mathbf{Z} \bm{\mu} \tag{7}

\tilde{\mathbf{X}} 相比于 \mathbf{X} ,具有低秩的特性。从下图中可看出,其在保持类间差异的同时,类别内部差异得到缩小。从图像角度来看,起到了类似保边滤波的效果。

综上,EMA在获得低秩重构特性的同时,将复杂度从Nonlocal的 O \left( N^2 \right) 降低至 O \left( NKT \right) 。实验中,EMA仅需 3 步就可达到近似收敛,因此 T 作为一个小常数,可以被省去。至此,EMA复杂度仅为 O \left( NK \right) 。考虑到 K \ll N ,其复杂度得到显著的降低。

期望最大化注意力模块

EMA Unit

期望最大化注意力模块(EMAU)的结构如上图所示。除了核心的EMA之外,两个 1\times  1 卷积分别放置于EMA前后。前者将输入的值域从 R^+ 映射到 R ;后者将 \tilde{\mathbf{X}} 映射到 \mathbf{X} 的残差空间。囊括进两个卷积的额外负荷,EMAU的FLOPs仅相当于同样输入输出大小时 3\times 3 卷积的 1/3 ,参数量仅为 2C^2 + KC

对于EM算法而言,参数的初始化会影响到最终收敛时的效果。上一节中讨论了EMA如何在单张图像的特征图上进行迭代运算。而对于深度网络训练过程中的大量图片,在逐个批次训练的同时,EM参数的迭代初值 \bm{\mu}^{\left(0\right)} 理应得到不断优化。本文中,迭代初值 \bm{\mu}^{\left(0\right)}的维护参考BN中running_mean和running_std的滑动平均更新方式,即:

\bm{\mu}^{\left(0\right)} \leftarrow \alpha \bm{\mu}^{\left(0\right)} + \left( 1 - \alpha \right) \bm{\bar\mu}^{\left(T\right)} \tag{8}

其中, \alpha \in \left[ 0, 1 \right] 表示动量;\bm{\bar\mu}^{\left(T\right)} 表示 \bm{\mu}^{\left(T\right)} 在一个mini-batch上的平均。

此外,EMA的迭代过程可以展开为一个RNN,其反向传播也会面临梯度爆炸或消失等问题。此外,公式(8)也要求 \bm{\mu}^{\left(0\right)}\bm{\bar\mu}^{\left(T\right)} 的差异不宜过大,不然初值 \bm{\mu}^{\left(0\right)} 的更新也会出现不稳定。RNN中采取LayerNorm(LN)来进行归一化是一个合理的选择。但在EMA中,LN会改变基的方向,进而影响其语义。因为,本文选择L2Norm来对基进行归一化。这样,\bm{\mu}^{\left(0\right)} 的更新轨迹便处在一个高维球面上。

此处,我们可以考虑下EMA和A2Net[5]的关联。A2Net的核心算子如下:

\mathbf{Y}\!=\left[\phi\!\left(\!\mathbf{X}, W_{\phi} \right) softmax\left(\theta \left( \mathbf{X}, W_{\theta}\right)\right)^\top\!\right]softmax\left( \rho \left( \mathbf{X}, W_{\rho} \right) \right)\  \tag{9}

其中 \theta, \phi, \sigma 代表三个 1\times 1 卷积,它们的参数分别为 W_{\theta}W_{\phi}W_{\sigma} 。如果我们将 \theta\phi 的参数共享,并将W_{\theta}W_{\phi} 记作 \bm{\mu} 。那么, softmax \left( \theta \left( \mathbf{X}, W_{\theta} \right) \right) 和公式(5)无异;而 \left[ \cdot \right] 内则在更新 \bm{\mu} ,即相当于 A_EA_M 迭代一次。因此,A2-Block可以看作EMAU的特殊例子,它只迭代一次EM,且 \bm{\mu} 由反向传播来更新。而EMAU迭代 T 步,用滑动平均来更新 \bm{\mu}

实验

首先是在PASCOL VOC上的消融实验。这里对比了不同的 \bm{\mu} 更新方法和归一化方法的影响。

可以清楚地看到,EMA使用滑动均值(Moving average)和L2Norm最为有效。作为对比,Nonlocal和A2Net的模块作为语义分割头,在同样设置下分别达到 77.78%和77.34%的分数,而EMANet仅迭代一次时分数为77.34%,三者无显著差异,符合上文对Nonlocal和A2Net的分析和对比。接下来是不同训练和测试中迭代次数 T 的对比实验。

可以发现,EMA仅需三步即可近似收敛(精度不再增益)。而随着训练时迭代次数的继续增长,精度有所下降,这是由EMA的RNN特性引起的。

接下来,是EMANet和DeeplabV3、DeeplabV3+和PSANet的详细对比。

可以发现,EMANet无论在精度还是在计算代价上,都显著高于表中几个经典算法。

在VOC test server上,EMANet在所有使用ResNet-101的算法中,取得了最高的分数。此外,在PASCAL Context和COCO stuff数据集上也表现卓越。

最后是学习到的注意力图的可视化。如下图, i, j, k, l 表示四个随机选择的基的下标。右边四列绘出的是它们各自对应的注意力图。可以看到,不同的基会收敛到一些特定的语义概念。

参考

  1. ^FCN https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf
  2. ^Deeplab https://arxiv.org/abs/1606.00915
  3. ^PSPNet https://arxiv.org/abs/1612.01105
  4. ^Non-local Neural Networks https://arxiv.org/abs/1711.07971
  5. ^A2Net https://papers.nips.cc/paper/7318-a2-nets-double-attention-networks.pdf
编辑于 2019-09-11

文章被以下专栏收录