多角度理解自然梯度

以前写过一个 如何理解自然梯度 的回答,当时主要是从约束优化和拉格朗日乘子来得到自然梯度的过程。除此之外自然梯度还有多种理解方式,这里总结一下。

Fisher 矩阵与 KL-散度 xv

自然梯度的概念和Fisher矩阵与KL-散度是密切相连的。Fisher 矩阵定义为

F= E_{p(x|\theta)} \left[\nabla_{\theta} \log p(x|\theta) \nabla_{\theta} \log p(x|\theta)^T \right]



性质1. Score 函数的期望为零 E_{p(x|\theta)} [\nabla_{\theta} \log p(x|\theta)]=0

性质2. Fisher矩阵的两种表示形式 F= -E_{p(x|\theta)} \left[\frac{\partial ^2 }{\partial \theta \partial \theta^T} \log p(x|\theta) \right]

性质3. KL-散度的局部二阶近似为 D_{KL}[p(x|\theta) \| p(x|\theta+\delta\theta)] \approx \frac{1}{2} \delta\theta^T F \delta\theta

最基本的结论是:对于两个概率分布,KL-散度衡量了两个概率分布之间的差异,Fisher信息矩阵(FIM)是KL-散度的二阶近似,实际定义了概率分布空间上局部曲率。

统计流形上的最速下降

对于欧式空间上的目标函数,最常用的方法是梯度下降

-\frac{\nabla_{\theta} L(\theta)}{\|\nabla_{\theta} L(\theta)\|} = \lim_{\epsilon\to 0} \frac{1}{\epsilon} \arg\min_{\|\delta \theta\| \leq\epsilon} L(\theta+\delta \theta)

此式含义是,梯度方向是下降速度最快的方向,即最陡峭的方向。在空间中任何一个方向,在局部范围内下降的速度都不如负梯度方向快。需要注意的是,下降速度本身是一个比值(的极限),下降速度最快不代表沿此方向下降幅度最大。

不同空间上,最速下降方向的推导是依赖于 \|\delta \theta\| 所的范数——距离度量。距离度量在这里起着核心作用,不同的度量会得到不同的最速下降方向。对于欧式范数,最速下降方向就是负梯度方向。在概率分布空间,每个参数 \theta 表示一个参数化的概率分布,分布之间的距离用KL-散度表示,于是上面的右面的优化问题表示为

\min_{D_{KL}[p_{\theta}||\theta + \delta \theta] \leq\epsilon} L(\theta+\delta \theta)

将此式写成拉格朗日乘子法的形式

\min L(\theta+\delta \theta) + \lambda \left(D_{KL}[p(\theta||p_{\theta+\delta \theta})]-\epsilon \right) \approx L(\theta) + \nabla L(\theta)^T \delta \theta+ \lambda \left(\frac{1}{2}\delta \theta^T F \delta \theta -\epsilon\right)

对右边取梯度并令梯度为零,可得 \delta\theta^* = -\frac{1}{\lambda}F^{-1}\nabla_{\theta} L(\theta) ,即最速下降方向由 F^{-1}\nabla_{\theta} L(\theta) 方向确定(相差一个常数因子,可以和学习率合并)。此方向称为自然梯度。由于Fisher矩阵表示统计流形(概率分布空间)上的局部曲率,因此这个方向实际考虑了分布参数空间上的曲率信息。KL散度和自然梯度在参数变换下保持不变。



自然梯度与二阶优化的关系

1.Fisher矩阵是对数似然函数的Hessian矩阵的期望 F= -E_{p(x|\theta)} \left[\frac{\partial ^2 }{\partial \theta \partial \theta^T} \log p(x|\theta) \right]

Fisher Information Matrix is equal to the negative expected Hessian of log likelihood.

具体的推导可以参考这里

2. 自然梯度与Gauss-Newton法

对于MSE的loss函数 L(\theta) = \frac{1}{2}\sum_i (f(x_i,\theta) - y_i)^2 ,Gauss-Newton法是牛顿法的近似。通过链式法则,上式的Hessian矩阵可以写成

\nabla^2L(\theta) = \sum_i \nabla_{\theta}f(x_i, \theta)\nabla_{\theta}f(x_i, \theta)^T + \sum_i r_i \nabla_{\theta}^2f(x_i,\theta)

其中 r_i = f(x_i,\theta)-y_i 是残量。上式的第一项就是Gauss-Newton矩阵。对于较小的残量来说,L的Hessian矩阵就可以用右边第一项来近似。因此,尽管他们的出发点是完全不同的,Gauss-Newton矩阵与Fisher矩阵相同,自然梯度下降与Gauss-Newton法一致。

梯度方向的不确定性

对于绝大部分概率分布,Fisher矩阵都无法解析计算,只能进行数值估计。给定一组数据 \{x_i\}_{i=1}^N ,Fisher矩阵可以估计如下

F = \frac{1}{N}\sum_{i=1}^N \nabla_{\theta} \log p_{\theta}(x_i)\nabla_{\theta} \log p_{\theta}(x_i)^T

此式称为经验Fisher矩阵。在mini-batch的情形,此式可以对mini-batch做移动平均来逐步近似Fisher。记 g_i = \nabla_{\theta}\log p_{\theta}(x_i) ,上式可以看成是 g_i 的协方差矩阵,即 F=cov[g] ,描述了梯度的不确定性,自然梯度方向是目标函数值下降的概率最大的方向,参考Topmoumoute online natural gradient algorithm

编辑于 2020-12-19 22:11