Chebyshev多项式作为GCN卷积核

Chebyshev多项式作为GCN卷积核

利用Chebyshev多项式拟合卷积核是GCN论文中广泛应用的方法在这篇文章中,我会推导相应的公式,并举一个具体的栗子。

在之前的回答中(如何理解 Graph Convolutional Network(GCN)?),已经推导出了如下GCN的形式:

y=\sigma \left(U  g_\theta(\Lambda)  U^T x \right) \tag{1}

其中, U 是由拉普拉斯矩阵的特征向量构建的矩阵, \Lambda 是特征值构成的对角矩阵, g_\theta(\Lambda) 是卷积核, x 是输入特征, \sigma(\cdot) 是激活函数。

利用Chebyshev多项式代替卷积核,就可以得到下式:

g_\theta(\Lambda)=\sum_{k=0}^{K-1}{\beta_k T_k (\tilde{\Lambda})} \tag{2}

其中,T_k(\cdot)k 阶的Chebyshev多项式,\beta_k 是对应的系数(也就是训练中迭代更新的参数)。 \tilde{\Lambda} 是re-scaled的特征值对角矩阵,进行这个shift变换的原因是Chebyshev多项式的输入要在 \left[ -1,1\right] 之间。

这里大家可能会疑问为什么会有 \left[ -1,1\right] 这个区间的限制?因为第一类Chebyshev多项式的解析形式是(详细内容可以参考第50期 切比雪夫多项式及其应用):

T_k(x)=\cos(k\cdot \arccos(x)) \tag{3}

因为 \arccos(\cdot) 函数,所以输入必须在\left[ -1,1\right] 之间。

那么现在问题来了,如何把 \Lambda 转换在上述区间?一共有两步:

  • 由于 \Lambda\geq0 (原因为拉普拉斯矩阵半正定,特征值非负),除以最大特征值 \lambda_{max} ,就转化在 \left[ 0,1\right] 区间
  • 再进行 2\times\left[ 0,1\right]-1,就实现了目标。于是就有:

\tilde{\Lambda}=2\Lambda/\lambda_{max}-I \tag{4}

把式(2)带入到式(1)中,即可得到:

y=\sigma \left(U  \sum_{k=0}^{K-1}{\beta_k T_k (\tilde{\Lambda})}  U^T x \right) \tag{5}

因为Chebyshev多项式作用在对角矩阵上,不会影响矩阵运算。那就改变一下运算顺序,先把矩阵运算放进去

y=\sigma \left(  \sum_{k=0}^{K-1}{\beta_k T_k (U\tilde{\Lambda}U^T)}   x \right) \tag{6}

因为 L=U \Lambda U^TL 表示拉普拉斯矩阵),代入可以得

y=\sigma \left(  \sum_{k=0}^{K-1}{\beta_k T_k (\tilde{L})}   x \right) \tag{7}

其中, \tilde{L}=2L/\lambda_{max}-I 。这样变换的好处在于:计算过程无需再进行特征向量分解。

(最大特征值 \lambda_{max} 可以利用幂迭代法(power iteration)求出,详细内容可以参考cnblogs.com/fahaizhong/

在实际运算过程中,可以利用Chebyshev多项式的性质,进行递推:

T_k (\tilde{L})=2\tilde{L}T_{k-1} (\tilde{L})-T_{k-2} (\tilde{L}) \tag{8}

T_{0} (\tilde{L})=I,T_{1} (\tilde{L})=\tilde{L} \tag{9}

下面来举个栗子,以下图为例

graph示意图

这里我们利用对阵型拉普拉斯矩阵, L^{sys}=I-D^{-0.5}AD^{-0.5}

\lambda_{max}\approx1.88

  • 当K=1时,卷积核为

\left[  \begin{matrix}    \beta_0 & 0 & 0 &0&0&0\\    0 & \beta_0 & 0 &0&0&0 \\    0 & 0 & \beta_0 &0&0&0\\    0 & 0 & 0 &\beta_0&0&0\\   0 & 0 & 0 &0&\beta_0&0\\    0 & 0 & 0 &0&0&\beta_0   \end{matrix}   \right]

当K=2时,卷积核为

\left[  \begin{matrix}    \beta_0+0.07\beta_1 & -0.44\beta_1 & 0 &0&-0.44\beta_1&0\\    -0.44\beta_1 & \beta_0+0.07\beta_1 & -0.44\beta_1 &0&-0.36\beta_1&0 \\    0 &- 0.44\beta_1 & \beta_0+0.07\beta_1 &-0.44\beta_1&0&0\\    0 & 0 & -0.44\beta_1 &\beta_0+0.07\beta_1&-0.36\beta_1& -0.62\beta_1\\  -0.36\beta_1 & -0.36\beta_1 & 0 &-0.36\beta_1&\beta_0+0.07\beta_1&0\\    0 & 0 & 0 &-0.62\alpha_1&0&\beta_0+0.07\beta_1   \end{matrix}   \right]

结合图的邻接关系,明显可以看出卷积核的localize特性。

通过观察我们可以发现,当K=2时,对角线上的卷积系数中  \beta_1 前的系数很小。这种方式的好处在于:顶点自身特征基本由  \beta_0 控制,  \beta_1 控制一阶邻居的特征。

代码实现可以参考(源码地址:github.com/tkipf/gcn):

def chebyshev_polynomials(adj, k):
    """Calculate Chebyshev polynomials up to order k. Return a list of sparse matrices (tuple representation)."""
    print("Calculating Chebyshev polynomials up to order {}...".format(k))

    adj_normalized = normalize_adj(adj) # D^{-1/2}AD^{1/2}
    laplacian = sp.eye(adj.shape[0]) - adj_normalized  # L = I_N - D^{-1/2}AD^{1/2}
    largest_eigval, _ = eigsh(laplacian, 1, which='LM') # \lambda_{max}
    scaled_laplacian = (2. / largest_eigval[0]) * laplacian - sp.eye(adj.shape[0]) # 2/\lambda_{max}L-I_N

    # 将切比雪夫多项式的 T_0(x) = 1和 T_1(x) = x 项加入到t_k中
    t_k = list()
    t_k.append(sp.eye(adj.shape[0])) 
    t_k.append(scaled_laplacian)
    
    # 依据公式 T_n(x) = 2xT_n(x) - T_{n-1}(x) 构造递归程序,计算T_2 -> T_k项目
    def chebyshev_recurrence(t_k_minus_one, t_k_minus_two, scaled_lap):
        s_lap = sp.csr_matrix(scaled_lap, copy=True)
        return 2 * s_lap.dot(t_k_minus_one) - t_k_minus_two

    for i in range(2, k+1):
        t_k.append(chebyshev_recurrence(t_k[-1], t_k[-2], scaled_laplacian))

    return sparse_to_tuple(t_k)

关于参数共享方式的讨论,可以参考我的另一篇文章

superbrother:解读三种经典GCN中的Parameter Sharingzhuanlan.zhihu.com图标

在我最近发表的一篇论文中:就将结合Chebyshev多项式的GCN作为基于有限检测器的路网规模交通流量估计问题(一种特殊的时空矩阵填充问题)的baseline。既原文4.2节部分的CGMC模型,目前论文可以在12月27号前免费访问。感兴趣的朋友可以阅读如下的链接。

Please wait whilst we redirect youauthors.elsevier.com

Zhang, Z., Li, M., Lin, X., & Wang, Y. (2020). Network-wide traffic flow estimation with insufficient volume detection and crowdsourcing data.Transportation Research Part C: Emerging Technologies,121, 102870.


编辑于 11-12

文章被以下专栏收录