首发于FL-Graph
论文笔记:NeurIPS'21 Federated Graph Classification over Non-IID Graphs (GCFL)

论文笔记:NeurIPS'21 Federated Graph Classification over Non-IID Graphs (GCFL)

前言

基于联邦图机器学习,本文分析了来自于不同领域的真实图,以证实分布式存储的图确实具有某些图的属性,这些属性与随机图相比具有统计学意义。然而,不同的分布式存储的图数据,即使来自同一领域或同一数据集,在图结构和节点特征方面都是 no-IID(Feature distribution skew, Label distribution skew 和 Quantity skew)。

基于此,本文提出了一个图聚类联合学习(graph clustered federated learning,GCFL)框架,该框架基于 GNN 的梯度动态地找到局部系统的簇,并从理论上证明这种簇可以减少局部系统所拥有的图之间的结构和特征异质性。此外 GNN 的梯度在 GCFL 中是相当波动的,这阻碍了高质量的聚类,因此提出一个基于梯度序列的动态时间扭曲的聚类机制(GCFL+)。

如果大家对大图数据上高效可扩展的 GNN 和基于图的隐私计算感兴趣,欢迎关注我的 Github,之后会不断更新相关的论文和代码的学习笔记。

1. Motivation

现实世界中的图通常保留了许多共同的属性,但是来自异质信息源(如不同的数据集或甚至不同的领域)的现实图是否能在彼此之间提供有用的共同信息?基于此,作者设计如下实验:实验所分析的图属性如下:kurtosis of degree distribution(如果值较大代表节点度的长尾分布)、similar average shortest path lengths、largest component size、clustering coefficient(图中的节点之间结集成团的程度的系数)

p-value(在假设原假设正确的情况下获得至少与观察到的结果一样极端的检验结果的概率)

作者分析了来自于不同领域的数据集,PTC_MR(分子图)、ENZYMES(蛋白质)、IMDB-BINARY(社交网络)、MSRC_21(超像素)。基于表 1,发现它们确实具有某些属性,相比与具有相同数量节点和边的随机图,这些属性在统计学上更为显著。这可以在很大程度上影响图数据挖掘模型,并使跨数据集甚至领域的基于图分类的联邦学习成为可能。

虽然跨领域的图数据集之间存在共同的模式属性,但仍然存在异质性。事实上,由于各种原因,详细的图结构分布和节点特征分布都可能出现异质性。在跨数据集的联邦学习中,将可能具有显著异质性的图称为 no-IID 图,涉及到结构 no-IID 和特征 no-IID,在这种情况下,FedAvg 可能产生较差的效果。因此需要一个动态的 FL 算法来跟踪 no-IID 图的异质性,同时完成多个本地客户端的协同模型训练。

每个客户端拥有本地私有的图数据,客户端之间图数据的相似性的度量是解决该问题的一个可行方案,因此本文提出了一个 graph-level 的聚类 FL 框架(GCFL),通过将强大的图神经网络(GNN)如 GIN 整合到聚类 FL 中,服务器可以根据 GNN 的梯度动态地对客户进行聚类,而无需额外的先验知识,同时根据需要协同训练多个 GNN 以实现客户的同质化聚类。从理论上分析,GNN 的模型参数反映了图的结构和特征,因此使用 GNN 的梯度进行聚类原则上可以产生结构和特征异质性都降低的簇。(详细证明可见附录 A)

尽管 GCFL 在理论上可以得到同质化的簇,但在其训练过程中,每一轮通信中传输的梯度波动很大,可能是因为客户端之间存在在结构和特征方面的异质性复杂交互,使得局部梯度朝着不同的方向变化。在 GCFL 框架中,服务器只根据最后传输的梯度来计算聚类的矩阵。因此,进一步提出了一个基于梯度序列的 GCFL 改进版(称为GCFL+)。

2. Method

2.1 The FedAvg algorithm

FedAvg 是一种基于 SGD 的聚合算法。FedAvg 是第一个基本的 FL 算法,通常被用来作为更高级的 FL 框架设计的起点。FedAvg 的关键思想是将本地客户端传输的最新模型参数汇总,然后将平均参数重新分配给每个客户。

假设总共有 m 个客户端,在每一个通信回合 t ,服务器首先采样部分客户端 \{\mathbb{S}_i\}^{(t)} ,对于采样客户端集合的每一个 \mathbb{S}_i ,用自己的数据 \mathcal{D}_i 对从服务器下载的模型进行本地训练 E_{local} 个 epoch,之后客户端 \mathbb{S}_i 上传更新后的模型参数 w_i^{(t)} 给服务器,服务器通过聚合机制:

w^{(t+1)} = \sum_{i=1}^m\frac{|D_i|}{|D|}w_i^{(t)}\tag{1}

进行参数更新,其中 |D_i| 代表客户端 \mathbb{S}_i 数据样本大小, |D| 代表所有客户端的数据。在生成聚合参数(全局模型更新)后,服务器将新的参数 w^{(t+1)} 广播给远程客户端,在 (t+1) 轮客户端使用 w^{(t+1)} 开始本地训练,再进行 E_{local} 个 epoch。

2.2 Non-IID structures and features across clients

  • 对于结构异质性,本文使用匿名随机游走(AWE)为每个图生成一个表示,并计算每对图的AWE 之间的 Jensen-Shannon distance;
  • 对于特征异质性,计算图中所有连接节点对之间的特征相似度的经验分布,并计算每对图的特征相似度分布之间的 Jensen-Shannon divergence。

如表 2,在单一数据集、单一领域和不同领域中,图的结构和特征都表现出不同程度的异质性。本文把具有这种结构和特征异质性的图称为 no-IID 图。将 FedAvg 等 简单的 FL 算法应用于拥有 no-IID 图的客户端,可能产生糟糕的效果。具体来说,结构的异质性使得模型很难捕捉到不同客户端之间的图结构模式,而特征的异质性使得一个模型很难学习到不同客户端之间的消息传播函数。

2.3 GCFL

为了解决上述问题,本文提出 Graph Clustered Federated Learning (GCFL)。主要思想是寻找具有类似结构和特征的图的客户端簇,并在同一簇的客户端中用 FedAvg 训练图模型。

考虑一个具有 n 个客户端的集合 \{\mathbb{S}_1,\mathbb{S}_1\dots,\mathbb{S}_n\} ,服务器可以动态地将客户端聚类成一个大小为 m 的簇的集合 \{\mathbb{C}_1,\mathbb{C}_2.\dots\} ,由于本文是 graph-level,因此每个客户端 \mathbb{S}_i 拥有本地图集合 \mathcal{G}_i = \{G_1,G_2,\dots\} ,图表示为 G_j = (V_j,E_j,X_j,y_j)\in\mathcal{G}_i ,每个客户端 \mathbb{S}_i 的任务是通过预测类别标签来完成 graph-level 的图分类任务,形式化的表示形式为 \hat{y}_j = h_k^*(G_j) ,其中 G_j \in \mathcal{G}_i h_k^*\mathbb{S}_i 所属集群 \mathbb{C}_k 的协同学习的最优图模型。算法框架最小化对于每一个簇类 \{\mathbb{C}_k\} 的损失函数 F(\Theta_k):=E_{\mathbb{S}_i\in\mathbb{C}_k}[f(\theta_{k,i};\mathcal{G}_i)] ,同时在模型训练过程中不同动态得到最优的簇类中心 \Gamma(\mathbb{S}_i)\rightarrow\{\mathbb{C}_k\}

GCFL 框架通过利用客户端的传输梯度 \{\Delta \theta_i\}_{i=1}^n 来动态地对客户端进行聚类,以最大限度地提高更多同质客户之间的协作,消除异质客户对协同训练模型的影响。如果客户端的数据分布是高度异质的,客户端的协同训练不能共同优化局部损失函数。在这种情况下,经过几轮通信后,一般 FL 将接近静止点,而且客户传输梯度的规范不会全部趋于零。因此,在一般 FL 接近静止点的时候,需要对客户进行聚类。引入一个超参数 \epsilon_1 (是否需要根据现有的簇重新划分本地客户端)作为标准,根据是否接近静止点来决定是否停止一般的 FL:

\delta_{mean} = ||\sum_{i\in\ [n]}\frac{\mathcal{G}_i}{\mathcal{G}}\Delta \theta_i||<\epsilon_1\tag{2}

同时,如果有一些客户端仍然存在较大的传输梯度,这意味着簇中的现有客户端之间是高度异质的,因此需要进行聚类来消除客户端之间的负面影响。基于此引入 \epsilon_2 (是否需要基于现有的簇生成新的簇实现更好的同质性聚类)来分割现有的簇:

\delta_{max} = max(||\Delta \theta_i||)>\epsilon_2>0\tag{3}

GCFL 遵循自上而下的双向分配机制,在每个通信轮 t ,服务器端接收到 m 个梯度的集合 \{\{\Delta\theta_{i_1}\},\{\Delta\theta_{i_2}\},\dots,\{\Delta\theta_{i_m}\}\} ,其中传输梯度来自于包含了多个客户端的不同的簇 \{\mathbb{C}_1,\mathbb{C}_2,\dots,\mathbb{C}_m\} 。对于聚类簇 \mathbb{C}_k ,如果 \delta_{mean}^k,\delta_{max}^k 满足式(2)和式(3),服务器将计算出一个簇内的余弦相似度矩阵 \alpha_k ,基于此建立一个全连接图,节点为簇内的所有客户端,之后 Stoer–Wagner minimum 分割算法被用于所构造的全连接图,对图进行双分区,并将簇划分为 \mathbb{C}_k\rightarrow\{\mathbb{C}_{k1},\mathbb{C}_{k2}\} 。基于式(2)和式(3)可以自动、动态地完成聚类。

簇类中心 \mathbb{C}_k 的客户端 \mathbb{S}_i ,尝试寻找 \hat{\theta}_{k,i} 来接近真实解 \theta^*_{k,i} = \arg\min_{\theta_i\in\Theta_k}f(\theta_{k,i};\mathcal{G}_i) ,每一个通信轮 t ,客户端将其梯度传输给服务器:

\Delta \theta_{k,i}^t = \hat{\theta}_{k,i}^t - \theta_{k,i}^{t-1}\tag{4}

由于服务器维护簇的分配,可以通过以下方式将梯度按簇进行汇总:

\theta_k^{t+1} = \theta_k^i + \sum_{i\in[n_k]}\Delta \theta_{k,i}^t\tag{5}

2.4 Theoretical analysis

详细证明可见附件 B,本文使用 Bourgain theorem 来约束以不同图结构/特征生成的嵌入之间的差异,并证明图的特征和结构信息被纳入模型权重(梯度)。通过证明模型权重(梯度)与结构/特征的差异是有界限的,表明梯度将随着结构和特征的变化而变化。这进一步证明了 GCFL 能够捕获结构和特征信息。此外,基于以下问题,GCFL 框架在未来可以进一步扩展到跨任务的 graph-level 联合学习。

2.5 GCFL+: improved GCFL based on observation sequences of gradients

如图 1 所示,展示了 GCFL 中每一轮通信的传输梯度:

  1. 传输梯度持续波动;
  2. 不同客户端的传输梯度有不同的尺度。

梯度规范的波动和不同的尺度表明,客户端的梯度更新方向和步长是不同的,体现了结构和特征的异质性。在 GCFL 框架中,一旦聚类标准得到满足,服务器就会根据最后传输的梯度计算余弦相似度矩阵。然而,根据观察,尽管有聚类标准的约束,梯度的规范在通信回合中是波动的,基于梯度点的 GCFL 聚类可能会遗漏重要的客户行为并被噪音误导。GCFL 在第 119 轮根据该轮的梯度进行聚类,这并不能有效地找到异质性较低的图。

基于此提出改进的 GCFL(GCFL+),它通过考虑传输梯度的序列来进行聚类。在 GCFL+ 框架中,服务器维护一个多变量时间序列矩阵 Q\in\mathbb{R}^{\{n,d\}} ,其中 n 是客户端的数量, d 是被追踪的梯度序列的长度。在每个通信回合 t ,服务器通过向 Q(i,:)\in\mathbb{R}^d 添加传输梯度 ||\Delta \theta_i^t|| 来更新 Q ,并删除过时的传输梯度。GCFL+ 使用与 GCFL 相同的聚类标准(式(2)和式(3))。如果聚类标准得到满足,服务器将计算出一个距离矩阵,其中每个元素是两个序列梯度的成对距离。本文使用一种叫做动态时间扭曲(DTW)的技术来衡量两个数据序列之间的相似性。对于一个集群 \mathbb{C}_k ,服务器计算其距离矩阵为

\beta_k(p,q) = dist(Q(p,:),Q(q,:)),\;\;p,q\in idx(\{\mathbb{S}_i\})\tag{6}

其中 idx(\{\mathbb{S}_i\} 代表在簇 \mathbb{C}_k 中所有本地客户端 \{\mathbb{S}_i\} 的索引。有了距离矩阵,服务器可以对满足聚类标准的集群进行双分区。因此,在图1(b)中,GCFL+ 在第 118 轮基于长度为 10 的梯度序列进行聚类,这就抓住了客户的长距离行为,有效地提高了集群的同质性。

3. Experiments

实验细节可见 Experimental settings

Federated graph classification within single datasets

Federated graph classification across multiple datasets

在两种情况下对多个数据集进行了实验:

单域(MOLECULES),和跨域(BIOCHEM和MIX)

Effects of hyper-parameters \epsilon_1 and \epsilon_2

超参数 \epsilon_1 是一个停止标准,用于检查当前客户端集合上的一般 FL 是否接近静止点。理论上, \epsilon_1 应该被设置得越小越好。超参数 \epsilon_2 更依赖于客户端的数量和客户端之间的异质性。较小的 \epsilon_2 会使客户更有可能被集中起来。当 \epsilon_1\epsilon_2 在可行的范围内时,它们的微小变化对性能的影响不大,因为聚类的结果基本保持不变。当 \epsilon_2 设置得过大时,其性能将与直接应用基本的 FL 算法相似(即有一个集群)。当 \epsilon_2 设置得太小时,将产生更多的小规模集群,甚至是单一客户。

Structure and feature analysis in clusters

数据集的的来源信息可以在一定程度上验证集群的合理性,但不能完全依靠这种先验知识来确定最佳的集群,GCFL 具有沿 FL 过程的性能驱动的动态聚类能力。

Convergence analysis

将测试损失与通信轮次的关系可视化,图 4 显示了两种设置下的训练曲线,说明 GCFL 和GCFL+ 实现了与 FedProx 相似的收敛率,FedProx 是处理 no-IID 欧氏数据的最先进的 FL 框架。与 FedAvg 相比,GCFL、GCFL+ 和 FedProx 都能收敛到较低的损失。

编辑于 2021-12-16 09:56