谈谈神经网络里的Lasso

Lasso是一个常见的统计方法,常用于feature selection。但是即便是最简单的线性模型,Lasso的使用往往也有很多坑会被踩(可能大部分本文读者并不了解),这篇文章就简单谈谈怎么把Lasso用到神经网络模型的压缩这个领域来。

什么样的Lasso问题是合理的?

做任何模型都有一个与现实世界相关联的目的,而formulate一个优化问题并求解严格来说并不能算满足这样的目的。为什么呢?让我们先看一个故事:

Y同学今年本来打算去阿拉斯加玩耍,要定一个Arctic Circle游玩的day tour,这个tour从Fairbanks出发包含一段做小飞机达到Arctic Circle的旅程,但是她十分担心这个小飞机的安全性。于是作为一个有量化背景出身的妹子,她从NOAA的网站下载了过往20年所有小飞机出事故的数据,希望能通过线性建模来计算如果坐这种小飞机对生命安全的风险。

假设她下载的数据有很多特征,比如事故发生当天的日期与天气,运营飞机的航空公司,飞机的执勤年数,驾驶员的飞行里程数,飞机运载的人数,飞行的里程数等等。作为第一步,她想知道哪些特征会跟事故的安全性有关,于是她决定使用Lasso模型来做feature selection。她先将数据都通过特征工程处理成连续的特征,然后打算求解一个带Lasso的线性模型。这时她发现一个有趣的问题:

Lasso模型选取的特征不仅和她选取的penalty有关,而且跟她如何scaling数据各个维度的特征也是密切相关的。

她就疑惑了,原则上来说数据各个特征维度彼此是刻画不相关的方面,它们的数值scale不应该决定最后的特征选择。于是她就去请教S大的统计大佬Trevor Hastie,Prof Hastie告诉她你没有好好学习我的《统计学习基础》这本书哦?

在使用Lasso模型之前,应该先把每个特征维度都normalize,否则你得到特征选择的结果是没法解释的。

Y同学回去后仔细研读了相关的文献,发现Lasso的使用是有recovery condition,准确的说就是假设数据满足有一个low rank的true model,那么noisy features和true features必须有合适的covariance结构,Lasso才能有效的recover true model。一个经典的结果是Peng Zhao读博期间写的文章 On Model Selection Consistency of Lasso 提供的。(对特征选取如此精通的Peng Zhao博士毕业后从事高频交易的策略研发,可以说以一己之力改变了这块HFT的格局:Citadel Securities 新任 CEO Peng Zhao 是怎样的人?

实际上在工程领域,Lasso和相关模型被大面积使用,使用L1 norm的优化来enforce稀疏性已经成了Deep Learning时代以前的一大灌水利器。但是很多人都忘记了,获得稀疏性并不是目的,而是手段;真正的目的是调选重要的特征!

凸问题里的Lasso优化

抛开Lasso问题的formulation是否跟现实世界相符合,优化领域对这类问题也是非常感兴趣的,一个经典的方向就是如果开发高效方便的优化方法。

这个时候不得不提Iterative Shrinkage Thresholding Algorithm (简称ISTA),这个方法是如此简单优美,让人爱不释手。它的算法只有一行,再对参数做一次梯度迭代后会再多调用一步:

x := \max\{|x| - \lambda \mu, 0\} \cdot \text{sign}(x)

其中 \lambda 是penalty, \mu 是learning rate。一个经典的结果是,如果优化的原问题是strongly convex的话,那么它的Lasso变体在ISTA迭代下是线性收敛的。

非凸优化里的Lasso和神经网络模型的压缩

到了非凸问题的领域,优化就变得很棘手了。一个最近比较火的案例就是如何让神经网络模型变得更加稀疏,从而减少一个网络在inference时候的计算量,从而方便deploy到production。这类问题这几年做的人非常多,而一个常见的套路就是find somewhere sparse。不得不说神经网络是如此复杂,以至于总是可以找到地方来enforce sparsity,只要对想稀疏化的部分加上regularization就行了。

这时候我就不得不再次强调,片面追求稀疏并不是目的,特征选择或者说区分出重要和不重要的计算部件才是目的。尤其是在network pruning这个方向,这点就显地格外重要,一个简单的线性模型尚且需要数据需要事先被normalize,那么在神经网络中随机的加regularization本身就是questionable的。

我们最近ICLR 2018的工作看到了这个问题,本着稀疏不是目的,(尽量)合理使用Lasso的原则,提出了对Batch Normalization中的scaling parameter使用ISTA优化的方法来实现CNN channel pruning。我们在经典的ImageNet问题上获得了比较可观的结果,成功把ResNet101压缩到原来的40%规模,损失了不到两个百分点的error。

由于是我实习时候的工作,代码版权归公司原则上无法分享。文章发表后,我收到了不少人的邮件询问我的实现细节希望重现我的实验结果,如果你也有疑问可以给我发信。



参考文献:

code: bobye/batchnorm_prune

Rethinking the Smaller-Norm-Less-Informative Assumption in Channel Pruning of Convolution Layers, Jianbo Ye, Xin Lu, Zhe Lin, and James Z. Wang, ICLR 2018

编辑于 2018-10-16