IJCAI-24,[2405.06312] FedGCS: A Generative Framework for Efficient Client Selection in Federated Learning via Gradient-based Optimization

摘要

​ 考虑问题:异构性、高能耗。因此,需要高效的客户端选择策略。传统方法的启发式方法和基于学习的方法,无法全面解决这些复杂问题。因此本文提出FedGCS,一种创新的生成式客户端选择框架,首次将客户端选择过程重新定义为一个生成任务(其实在最早说的那篇ASP选择任务里也是生成任务,只不过嵌入到了强化学习框架里;本文认为强化学习有效但没有考虑模型性能、延迟等方面)。FedGCS 受大语言模型方法的启发,通过构建连续表示空间,有效地编码了丰富的决策知识,并利用基于梯度的优化方法搜索最优的客户端选择方案,最终通过生成输出最佳结果。

该框架包括以下四个步骤:1.自动化数据收集:利用经典的客户端选择方法自动收集多样化的“选择-评分”数据对;2.训练编码器-评估器-解码器框架:基于上述数据构建连续表示空间;3.基于梯度的优化:在连续表示空间中优化以找到最佳的客户端选择;4.生成最优客户端选择:通过对训练良好的解码器使用束搜索(beam search)生成最终选择结果。

简介

image

↑↑↑强化学习不好收敛,启发式算法不全面 本文模型三个感叹号

问题建模

​ 主要是将离散客户端选择任务表述为生成任务,并在连续表示空间中进行基于梯度的优化以获得最优选择。首先在训练轮次开始时,在包含J个设备的设备池中,使用经典客户端选择算法收集 n 对“选择-评分(记作(s,p)”数据作为后续模型训练的训练数据。收集到的记录集合记为 R。对于p, 是客户端选择时的综合FL性能评分。

​ 本文评分考虑:1.全局模型的良好性能;2.较低的处理延迟;3.较低的能耗。因此本文的评价函数定义为: 是下游任务上性能,L和E是开发者指定的时延和能耗,分母上是实际的时延能耗。如果实际时延能耗超了,就要乘分数,就说明评分下降了。这些(si,pi)都是离散的。

​ 为了能有梯度,需要把这些离散数据表示到连续空间。具体来说需要学习三个模块:1.一个编码器,映射集合;2.一个评估器,基于的特征评估其得分p;3.一个解码器,把连续特征解码回。所有这些模块基于记录集合R优化。优化完后,就可以使用基于梯度的优化方法找最优选择了。这里对应着下图的1~2.

方法

框架的全图见下:

image

选择-评分数据收集

​ 用先前的启发式算法或强化学习方法完成。记作

连续空间

​ 在获得R后,需要把它映射到连续表征空间以可以梯度下降。

数据增强

​ 由于输入输出都是客户端选择集合,因此此处相关模型是seq2seq的模型。由于是集合,为了建模出这种跟顺序无关的特点,作者使用一种数据增强的手段:其实就是随机shuffle一个选择中的设备ID,以获得一个新的order。

编码器

​ 需要得到连续特征表示。作者采用单层的LSTM作为编码器,依次输入设备ID,得到输出

解码器

​ 即用生成设备Id。作者依旧采用单层LSTM,并用自回归的方式生成。在解码过程的第i步,可以得到LSTM内对应的隐藏状态;用点乘注意力可以集成的上下文信息,最终得到融合了信息的特征嵌入: ​ a就是注意力权重。之后,把两个h拼接起来,经过一个全连接层和一个softmax,就可以得到第i步的预测的分布P。把每一步的分布P相乘,就可以得到整个客户端分布。

​ 为了让生成的序列和真实的序列相似自然要设计损失函数。为此,在这一步最大化分布P的对数似然,即

评估器

​ 给定一个选择s序列,需要基于估计得分p。为此,先把上述得到的各个做一个平均以聚合特征信息,然后把这个平均特征送到评估器(就是一个前向网络)里预测分数p。这个损失就用MSE计算。

​ 整体损失就是这两个做一个加权和就OK了。

基于梯度的选择优化

​ 通过上面的过程,可以得到训练后的编码-解码器和评估器。之后,就可以用梯度下降的方法去寻找最优的客户端选择。为了合理正确的初始化模型,先选择最高的K个<s,p>对并编码成连续空间里的表示。

​ 假设其中一个编码结果是,为了得到评分更好的s序列,就将朝着梯度方向优化:。对所有的编码结果都执行这一操作,就得到了候选选择表示集合;再用评估器进行评分,就可以选择出最优的特征。之后,采用解码器解码出最终的选择。