《开放世界半监督学习》阅读笔记

《开放世界半监督学习》阅读笔记

题目:《Open-World Semi-Supervised Learning》

Abstract

传统上,监督和半监督学习方法被设计用于一种封闭世界的设置,也就是说,基于一个未标记的数据只包含在已标记的训练数据中遇到的类的假设。 然而,真实的世界本质上是开放和动态的,因此新颖的、以前看不到的类可能出现在测试数据中或者模型部署期间。本文引入了一个新颖的开放世界半监督学习设置,其中模型需要识别以前见过的类,以及发现在标记数据集中从未见过的新类

为了解决这个问题,本文提出了ORCA,一种学习同时分类和聚类数据的方法。ORCA将未标记数据集中的例子分类到以前见过的类中,或者通过将相似的例子组合在一起形成一个新的类。ORCA的关键思想是引入基于不确定性的自适应margin,有效地规避由可见类和新类/簇之间的方差不平衡引起的偏差。本文使用图像分类领域的三个常用数据集(CIFAR-10, CIFAR-100,ImageNet) 进行实验验证,结果表明,ORCA在已知类上的性能优于半监督方法,在新类上也优于新类发现方法。

Introduction

目前图像分类,文本分类这些任务,主要是通过获取大量已知类的标注数据,然后用神经网络学习到任务相关的知识表示。这种方式主要基于封闭世界假设,但是现实世界往往是开放的,期望一个人能够提前识别和预先标记所有类别/类,并手动监督机器学习模型通常是不现实的或昂贵的。

为了解决这种开放世界的问题,目前有2种思路:(1) OOD检测:能够识别已知类的数据,并且能够将所有未知类的数据检测出来,标为"unknown"。这种方法很好的保证了系统鲁棒性,但是无法充分利用未知类数据进行业务扩展;(2) novel class discovery(零样本,领域自适应问题): 利用源域标记数据来学习更丰富的语义表示,然后将学到的知识迁移到目标域(包含新类别),对目标域数据进行聚类。这种方法不能准确识别出已知类,只是对目标域做了聚类。

本文提出一种开放世界半监督学习设置,在这种设置下,未标记的数据集可能包含从未在标记集中出现过的类,并且模型需要能够:(1)识别未标记数据的样本是否属于标记数据集中出现的已知类之一,以及(2)通过有效地将来自未标记数据的相似示例分组并将其分配给新的类/聚类,在没有任何先前知识的情况下自动发现新的/不可见的类。

因此,在开放世界的SSL设置中,对于每个未标记的示例,模型需要决定是将其分类到标记数据集中以前看到的类之一,还是将其分配到新的类。这意味着模型需要联合解决分类和聚类任务

为了解决开放世界SSL的挑战,作者提出了ORCA(基于不确定性的自适应margin),这种方法可以有效地将未标记数据中的例子分配给以前看到的类,或者通过在端到端深度学习框架中分组相似的例子来形成新的类/簇。

Contributions

  • 提出了开放世界半监督学习(区别于传统半监督学习,新类别发现以及OOD检测)
  • 提出了ORCA(基于不确定性的自适应margin)这种方法,通过利用标注数据和未标注数据优化一个联合损失函数,联合解决分类和聚类任务。关键思想在于在监督目标中引入基于不确定性的自适应margin。通过这种方式,ORCA减小了已知类和新类的类内方差之间的差距,提高了在未标记数据集中生成的伪标签的质量。

Method

使用标记和未标记的数据,ORCA学习一个联合嵌入函数和一个线性分类器,该分类器由可见的分类头和预期数量的新类的附加分类头组成。

  1. 预训练:本文做的是图像领域的任务,采用对比学习SimCLR进行预训练
  2. 下游任务的输出层(分类器):已知类的分类头用于将未标记的例子分配给已知类,而激活附加的分类头允许ORCA发现新类别。我们假设新类的数量是已知的,并将其作为算法的输入,这是聚类和新类发现方法的典型假设。如果不知道新类的数量,这在现实环境中是经常发生的情况,可以从数据中估计出来。在这种情况下,如果头的数量太多,那么ORCA将不会分配任何例子给一些头,所以这些头将永远不会激活,因此ORCA将自动修剪类的数量。我们在实验中进一步解决了这个问题。
  3. 损失函数:在ORCA框架下,提出了联合解决监督分类和无监督聚类任务的目标函数。ORCA中使用的目标函数结合了(I)监督目标,(ii)成对目标和(iii)均匀分布的正则化。

3.1 监督目标

这是标准的交叉熵损失,在标记数据上使用这个损失函数会在标记和未标记数据之间产生不平衡问题,即,对于已知类别更新梯度,而对于新类别则不更新梯度。这会导致整个模型偏向已知类。为了克服这个问题,在ORCA中我们引入了基于不确定性的自适应margin。

模型偏见主要是由于,在已知类标记数据上训练出来的模型在已知类和新类/簇的类内方差不平衡,导致容易出错的伪标签。为了减轻这种偏差,作者建议使用自适应margin来减小已知类和新类的类内方差之间的差距。

直观地说,在训练的开始,我们希望强制一个较大的负margin,以鼓励已知类相对于新类的类似的大的类内方差。在训练接近结束时,当已经为新类形成聚类时,我们将边缘项调整为接近0,以便模型可以充分利用有用的标签信息。

作者使用根据softmax的输出计算的未标记样本的置信度,估计不确定性来捕获类内方差。

3.2 成对目标

我们将聚类学习问题转化为成对相似性预测任务。 给定标记数据集Xl和未标记数据集Xu,我们旨在微调我们的嵌入函数fθ,并学习由线性分类器W参数化的相似性预测函数,使得来自同一类的样本被分组在一起。

为了实现这一点,我们依赖于来自标记集的ground truth标签和在未标记集上生成的伪标签。具体来说,对于有标签的集合,我们已经知道哪些对应该属于同一个类,因此我们可以使用ground truth标签。为了获得未标记集的伪标记,我们在一个小批量中计算所有特征表示对之间的余弦距离。然后,我们对距离进行排序,并为每个样本生成其最相似邻居的伪标签。

3.3 均匀分布正则化

通过在未标记数据上仅使用成对目标,ORCA可能退化为将所有样本分配到同一类的平凡解(变成一个OOD检测问题),即|Cu| = 1。为了避免这个问题,我们引入了KL散度项,正则化Pr(y|x ∈ Dl∪ Dl)使其接近均匀分布U

Experiments

4.1 实验设置

数据集:CIFAR-10,CIFAR-100,ImageNet

评价指标:已知类使用ACC,未知类使用NMI

baseline设置:作者将本文提出的ORCA方法与传统SSL方法以及新类别发现方法比较,分别看ORCA在已知类和新类上的表现。

4.2 实验结果分析

主实验结果

新类别数量的影响

估计新类别数量

消融实验(评估自适应margin的效果)

Conclusion

本文引入了开放世界半监督学习(SSL)设置,其中方法需要能够识别以前在标记数据集中遇到的已知类,以及发现新的、从未见过的类。为了解决这个问题,作者提出了ORCA,这是一种开放世界的SSL方法,可以有效地用基于不确定性的自适应margin来权衡类内方差。

实验表明ORCA在识别可见类的任务上明显优于SSL基线,在聚类不可见类的任务上优于新类发现基线。ORCA是一种独特的方法,它在端到端框架中联合解决了开放世界SSL的分类和聚类这两个任务。相关技术也可以应用于视觉以外的领域。

发布于 2021-05-22 17:29