优化策略5 Label Smoothing Regularization_LSR原理分析

优化策略5 Label Smoothing Regularization_LSR原理分析

Label Smoothing Regularization(LSR)是一种通过在输出y中添加噪声,实现对模型进行约束,降低模型过拟合(overfitting)程度的一种约束方法(regularization methed)。

获取最新消息链接:获取最新消息快速通道 - lqfarmer的博客 - 博客频道 - CSDN.NET

1、背景

假设有一个分类模型,预测观测样本x属于K个类别的概率。对于观测样本x,采用P(y’|x)表示模型对x的预测的概率分布,q(y|x)表示模型输出y的真实分布。分类器采用交叉熵做为目标函数,使模型预测的概率分布P(y’|x)尽量接近真实的概率分布q(y|x)。在进行模型训练时,通常采用类似one-hot-to-vec的0或1的方式,对真实分布q(y|x)进行编码,即观测样本属于某些类别,则对应类别的P(y|x)的值为1,否则为0。这样的编码方式存在两个明显的问题:

(1)、可能导致过拟合。0或1的标记方式导致模型概率估计值为1,或接近于1,这样的编码方式不够soft,容易导致过拟合。为什么?用于训练模型的training set通常是很有限的,往往不能覆盖所有的情况,特别是在训练样本比较少的情况下更为明显。以神经机器翻译(NMT)为例:假设预测句子“今天下午我们去..”中,“去”后面的一个词。假设只有“去钓鱼”和“去逛街”两种搭配,且真实的句子是“今天下午我们去钓鱼”。training set中,“去钓鱼”这个搭配出现10次,“去逛街”搭配出现40次。“去钓鱼”出现概率真实概率是20%,“去逛街”出现的真实概率是80%。因为采用0或1的表示方式,随着training次数增加,模型逐渐倾向于“去逛街”这个搭配,使这个搭配预测概率为100%或接近于100%,“去钓鱼”这个搭配逐渐被忽略。

(2)、模型become too confident about its predictions。情况与过拟合比较类似,模型对它的预测过于confident,导致模型对观测变量x的预测严重偏离真实的情况,比如上述例子中,把“去逛街”搭配出现的概率从80%放大到100%,这种放大是不合理的。

获取最新消息链接:获取最新消息快速通道 - lqfarmer的博客 - 博客频道 - CSDN.NET

2、Label Smoothing的原理及解释

Label Smoothing Regularization(LSR)就是为了缓解由label不够soft而容易导致过拟合的问题,使模型对预测less confident。LSR的方法原理:

假设q(y|x)表示label y的真实分布;u(y)表示一个关于label y,且独立于观测样本x(与x无关)的固定且已知的分布,通过下面公式(1)重写label y的分布q(y|x):

q’(y|x)=(1 - e)* q(y|x)+ e * u(y) (1)

其中,e属于[0,1]。把label y的真实分布q(y|x)与固定的分布u(y)按照1-e和e的权重混合在一起,构成一个新的分布。这相当于对label y中加入噪声,y值有e的概率来自于分布u(k)。为方便计算,u(y)一般服从简单的均匀分布,则u(y)=1/K,K表示模型预测类别数目。因此,公式(1)表示成公式(2)所示:

q’(y|x)= (1 - e)* q(y|x) + e / K (2)

注意,LSR可以防止模型把预测值过度集中在概率较大类别上,把一些概率分到其他概率较小类别上。

从交叉熵的角度,可以得到关于LSR的另一个解释。引入噪声分布u(k)之后,模型的交叉熵loss公式变为公式(3)所示。

H(q’,p)=-E(1...K)(log(p(k))*q’(k))

=(1 - e)* H(q,p)+ e*H(u,p) (3)

其中,E(1...K)表示从1至K的累加求和。因此,LSR相当于采用两个losses,即H(q,p)和H(u,p)来代替原始单一的交叉熵损失函数H(q,p)。u(k)是往label中加入的、已知的先验分布,按照e / (1 - e)的概率来偏移(deviation)预测分布p。这种偏移(deviation)可以通过KL距离来获得,H(u,p)= D(KL)(u||p)+H(u),其中,H(u)是已知且固定的。当u服从均匀分布时,H(u,p)衡量预测分布p与均匀分布u的不相似程度。

获取最新消息链接:获取最新消息快速通道 - lqfarmer的博客 - 博客频道 - CSDN.NET

往期内容推荐:

纯干货-8 21套深度学习相关的视频教程分享
模型汇总17 基于Depthwise Separable Convolutions的Seq2Seq模型_SliceNet原理解析
模型汇总16 各类Seq2Seq模型对比及《Attention Is All You Need》中技术详解
<模型汇总-6>堆叠自动编码器Stacked_AutoEncoder-SAE
深度学习模型-13 迁移学习(Transfer Learning)技术概述
发布于 2017-07-03

文章被以下专栏收录

    专注深度学习、NLP相关技术、资讯,追求纯粹的技术,享受学习、分享的快乐。欢迎扫描头像二维码或者微信搜索“深度学习与NLP”公众号添加关注,获得更多深度学习与NLP方面的经典论文、实践经验和最新消息。