CVPR 2020 | TRAML:利用自适应边际损失增强小样本学习

CVPR 2020 | TRAML:利用自适应边际损失增强小样本学习

简单介绍一下我们被 CVPR 2020 录用的一个小样本学习的工作《Boosting Few-Shot Learning With Adaptive Margin Loss》(代码稍后放出),论文地址如下:

Boosting Few-Shot Learning With Adaptive Margin Losswww.weiranhuang.com图标

1、背景

近年来,在深度网络的帮助下,计算机已经在图像识别上取得了超越人类的效果。然而,深度神经网络参数量巨大,因此需要大量有标签的数据来训练。现实世界中,有很多场景没有这么多的标注数据,获取标注数据的成本也非常大,例如在医疗领域、安全领域、终端用户手动标注等。相比之下,借助于之前丰富的知识积累,人类只需看一次就能轻松识别出新的类别。受到人类这种利用少量样本即可识别新类能力的启发,人们开始研究小样本学习问题:在没有大量训练数据的条件下,如何让深度神经网络也能够像人类一样,把过往的经验迁移到新的类别上。

在小样本学习问题中,我们假设有一组基类,以及一组新类。每个基类具有足够的训练样本,而每个新类只有少量标记样本。 小样本学习的目的是通过从基类转移知识来学习识别具有少量标注样本的新类别。常见的 setting 有如下两种:

  1. 标准小样本:给定一个大规模的训练集作为基类(base class),可以类比于人类的知识积累,对于从未见过的新类(novel class,与基类不重叠),借助每类少数几个训练样本,需要准确识别新类的测试样本。
  2. 广义小样本:相比与小样本学习,广义小样本学习中测试样本不仅包含新类,还可以包含基类,因此更具挑战性。

基于度量学习的元学习方法在小样本学习上表现出了很好的性能,它通过学习一个好的特征表示,使得在特征空间中,同类样本聚集,异类样本分开。这样,只需简单通过和各类样本的距离比较,就能预测一个新类样本所属的类别。

特征空间聚类结构和自适应边际示意图

在已有的方法中,相似类别的样本常常在特征空间里的距离挨得很近以至于难以区分,大大限制了分类精度。本文提出在类别之间加入自适应的边际距离来提升基于度量学习的元学习性能,其中的边际距离是通过类别之间的语义相似度自动生成的。直观上,语义上越相似的类别之间越难区分,设定的边际距离也应该越大。大量的实验表明,本文的方法在标准小样本分类和广义小样本分类任务上都显著超越了现有的方法。

2、回顾——基于度量学习的元学习方法

元学习(meta-learning)是一种处理小样本学习的常用框架,它包含 meta-training 和 meta-testing 两个阶段。在 meta-training 阶段,模型按照一个个 episode 来训练:在每个 episode 中,首先构造一个 task(从整个 base class 数据集中抽取一些样本来组成一个小训练集和小测试集) ,然后用它来更新模型。在 meta-testing 阶段,我们用学到的模型来预测 novel class 中的样本。近年来,基于度量学习(metric learning)的元学习方法变得很流行,它假设存在这样一个 embedding space,每个类别的样本聚类在一个代表点(class representation)周围,而这些类代表点当作每个类的参考点来预测测试样本的标签(比如距离测试点最近的 class representation 所对应的类别作为测试样本的标签)。【可以参看 Prototypical NetworkGlobal Class Representation 两个文章】

损失函数

在 meta-training 阶段,我们从 base class 数据集中随机抽取 n_t 个类(类别集合记作 C_t),每个类中随机抽取 n_s 个样本,合并后作为小训练集(称为支持集 S)。然后再从每个类随机抽取若干样本,合并后作为小测试集(称为问询集 Q)。支持集和问询集中的样本通过 embedding 函数 \mathcal F 映射到特征空间中,每个类的代表点用 r_1,\dots,r_{n_t} 来表示(比如,每个类所有支持集中的样本的中心点作为代表点)。最后引入一个度量模块 \mathcal D 来衡量两个特征向量的相似度(比如 cosine similarity)。

对于每个问询集中的点 (x,y)\in Q ,我们计算它的特征向量到各个类代表点之间到相似度,然后通过 softmax 计算损失函数如下

3、自适应边际损失(Adaptive Margin Loss)

为了更好地分开各个类别,一个最简单的加 margin 的方法是

上述 loss 称作 naive additive margin loss(NAML),它在类别两两之间加上了相同的边际 m,强迫不同类的样本之间分开一定的距离。这种简单加上等距离边际的方法在小样本测试集上(比如相似度很高的类别上)可能会带来错误。为了进一步精细化设计边际,我们借助类别之间的语义相似度,来自适应地生成边际。

在介绍自适应边际前,我们首先描述如何来衡量两个类别之间的语义相似度。具体来说,我们首先把各个类别的名称(比如“狗”)输入到一个预训练语言模型(比如 Glove)中,得到每个类别对应的语义向量(词向量),然后通过一个相似度度量模块(比如 cosine similarity)就可以计算类别两两之间的语义相似度了。根据类别之间的语义相似度,我们自适应地生成边际并加入到损失函数中。

自适应边际生成流程图

3.1 类别相关的边际损失(CRAML)

对于两个类别 ij ,首先得到它们的语义向量 e_ie_j 。然后我们通过线性模型 \mathcal M 来生成它们的边际,即

其中,\alpha\beta 是要学习的参数。于是,我们将损失函数改写为

通过合适地引入语义信息,CRAML 可以把相似的类别在特征空间中分的更开,从而帮助更好地识别新类的样本。

3.2 任务相关的边际损失(TRAML)

到目前为止,我们都只考虑边际与任务无关。如果每次只考虑一个 meta-training task 中涉及的类别,那么可以更加精细地生成适合的边际。通过将一个 meta-training task 中的每个类与该 meta-training task 中其他类一一比对,我们可以衡量一个 task 内“相对的”语义相似度,从而生成适合这个 task 的边际。

任务相关的边际生成示意图

具体来说,对于一个 meta-training task 中的类 y\in C_t,我们用一个神经网络 \mathcal G 来生成 task 内的边际(见上图),即

损失函数对应地改写为

也就是说,对一个问询集样本 (x,y)\in Q(比如一张 dog 的图片),我们首先计算它和 task 内其它每个类(cabinet,wolf,sofa)的语义相似度,然后把这些语义相似度通过神经网络 \mathcal G 来生成损失函数需要的边际,最后累加到损失函数 TRAML 中。

3.3 扩展到广义小样本学习

广义小样本学习中,测试数据既有来自新类也有来自基类,因此比标准小样本学习更加挑战。我们的自适应边际设计得非常灵活,用它训练得到的 embedding 和度量模块可以直接用来预测测试样本的标签。

4、实验验证

4.1 标准小样本学习

我们在 mini-ImageNet 上进行验证,选取了 AM3 和 Prototypical Network 作为 backbone。可以看到, TRAML 的引入显著提高了两个 backbone 的分类精度,这说明我们的设计可以有效增强基于度量学习的元学习方法。同时,AM3 + TRAML 超越了 state-of-the-art 的结果。

4.2 广义小样本学习

我们在 ImageNet 2012 上进行验证,选取了 Dynamic FSL 作为 backbone。我们首先加入 TRAML 在基类上训练 embedding 模块,然后用训练得到的 embedding 模块来提取所有训练样本的特征。后续 Dynamic FSL 中用到的 weight generator 采用刚刚计算出的特征作为输入。最后,我们把训练 weight generator 的原始分类损失函数替换为 TRAML 来进行训练。可以看到,TRAML 的引入在新类和全类上的性能都超越了 baseline,同时对于不同的 shot 数也都一致地好。

4.3 Ablation Study

我们以 AM3 为 backbone,分别测试了原始分类损失,NAML,CRAML 和 TRAML 的性能。

可以看到:

  • 相比于原始分类损失,TRAML 对分类精度有了显著的提高。
  • 各类采用相同的边际(NAML)对性能的提升非常有限,说明设计自适应的边际非常重要。
  • 自适应的边际 CRAML 和 TRAML 都对精度提升明显,其中更加精细设计的 TRAML 相比于 CRAML 对精度的提升更大。

另外,我们也在实验中观察到 CRAML 中学到的系数 \alpha 是正值,这也验证了我们的直觉,即相似的类别之间需要加更大的边际来加以区分。

5、总结

本文从 CV 和 NLP 多模态的视角切入,通过考虑类别的语义信息来提升小样本学习的性能。本文提出在类别之间加入自适应的边际距离来提升基于度量学习的元学习性能,其中的边际距离是通过类别之间的语义相似度自动生成的。直观上,语义上越相似的类别之间越难区分,设定的边际距离也越大。大量的实验表明,本文的方法在标准小样本分类和广义小样本分类任务上都显著超越了现有的方法。


最后,打个小广告,我们正在招 2-3 名实习生,如果您感兴趣,欢迎点击如下链接了解:

一份实习邀约zhuanlan.zhihu.com图标

编辑于 07-25

文章被以下专栏收录