首发于PaperWeekly
利用Uncertainty修正Domain Adaptation中的伪标签

利用Uncertainty修正Domain Adaptation中的伪标签

论文题目: Rectifying Pseudo Label Learning via Uncertainty Estimation for Domain Adaptive Semantic Segmentation

Paper 已经被 IJCV 接收了,感谢大家的意见和帮助,真的真的真的非常感谢。

论文地址: arxiv.org/pdf/2003.0377

作者:Zhedong Zheng, Yi Yang

代码: github.com/layumi/Seg_U

What:

  • 本文研究的是领域迁移问题中错误的伪标签(Pseudo Label)的问题,探讨了如何自动设定阈值来修正这种伪标签学习。
  • 之前的伪标签往往通过人为卡阈值(Threshold) 的方式来学高置信度的伪标签(Pseudo Label),而忽略低置信度的标签。 但是这个阈值(Threshold) 往往很难决定。
  • 我们思考如何定这个阈值,其实会发现这个阈值取决于 source 和 target 有多接近,如果source域 和 target 域 完全一样,那么阈值甚至可以卡到0.999,大多数data都有很确定的label,那全部样本都可以拿来训练; 但如果source域 和 target 域 完全不一样, 那么阈值再怎么卡,都没有用,因为所有的pseudo label都是noisy的。
  • 所以在做domain adpation的问题就在于 在实际场景中,我们往往只能凭经验来 提前设置阈值 让模型来学伪标签。这是不合理的,或者卡得太高或者太低都会影响模型学习的。所以本文我们提出用不确定性来作为阈值,等同于自动学一个阈值出来。

How:

  • 其实本质思想很简单,我们利用了分割模型中的辅助分类器(本来只是用来防止梯度消失的)。结构如下图:有一个主分类器接在res5c后面,一个辅助分类器接在res4b22. (本文研究的是图像分割, 如果您研究的是图像分类,也可以对应到分类问题中InceptionNet 系列,包括NASNet也是需要浅层辅助分类器的。)
  • 我们观察到 伪标签错误的地方,往往是 两个分类器预测不同结果的地方(Prediction Variance)。如下图:
  • 所以很自然的我们对于cross-entropy loss做了一个修正:其中D{kl} 就是主分类器和辅助分类器预测结果的KL距离(也可以叫作Prediction Variance),如果差异大,则这个距离也就大,那么Lrect 对于这种不确定的样本,就不惩罚(因为pseudo label很可能是错的)。如果没有后面+D_{kl} 这一项,模型很懒,会趋向把所有pseudo label都说成是不确定的,那么Lrect就等于0了。为了避免这种情况,所以我们加了一个+D_{kl} 。
  • 我们可以杠一下,其实原来的取阈值,是基于confidence, confidence高,model就学这种pseudo label,我们这边能不能用 confidence呢? 答案是不太好。
  • 因为pseudo label训练的时候就是增加这个pseudo label的confidence。 这样还是无法决定哪些样本可信。最后实验里,我也可视化了一下,confidence 和uncertainty的对比,如下图:(我们用的uncertainty 本质是 prediction variance,和错误pseudo label的区域有更明显的overlapping )
  • 同时,我们方法对于一些segmentation里面占得面积小的类别有奇效。因为传统方法按照confidence一刀切。往往数量小的类别(自行车Bike和骑行人Rider)这种类别吃亏,因为confidence score往往只有0.7,0.8,会被0.9的硬阈值卡掉。
  • 最后我们的方法可以在 两个常用benchmark (GTA5->Cityscapes 和 SYNTHIA -> Cityscapes)上有提升。并且和其他方法是互补的关系。 在一个小数据集上也有提升,但是由于类别比较少,所以提升不是很明显 (Cityscapes->Oxford RobotCar)。


感谢看完~~ 欢迎点赞转发评论~~也可以check我的一些其他文章

郑哲东:NVIDIA/悉尼科技大学/澳洲国立大学新作解读:用GAN生成高质量行人图像,辅助行人重识别zhuanlan.zhihu.com图标郑哲东:用CNN分100,000类图像zhuanlan.zhihu.com图标郑哲东:【新无人机数据集】从 行人重识别 到 无人机目标定位zhuanlan.zhihu.com图标Pytorch有什么节省显存的小技巧?www.zhihu.com图标

本文有一处勘误,公式7的 KL Divergence 中多加了负号,目前已在arXiv上更新了。感谢 @王晋东不在家 提出的意见和帮助。

编辑于 10-13

文章被以下专栏收录