为什么安全方案写在了基本联邦学习框架前面???

2.1 2.2的逻辑关系明显不合适(综合重排一下)

2.1 系统目标和威胁模型

  1. 本作品采用的边缘计算环境下的安全多方计算联邦学习系统采用经典的Client/Server模式,服务器和能够与之通信的多个客户端。训练开始之前由服务器向每个客户端提供一个共享模型,客户端将模型在自己的数据集上进行训练。出于保护隐私和减少通信开销的目的,服务器允许客户端在本地进行训练,而不必将客户端的数据集上传至服务器。客户端将本地训练过后得到的参数反馈给服务器,由服务器对所有客户端上传的参数进行整合,再下发调整后的共享模型。以该方式进行多次迭代,直至模型收敛。具体而言,联邦学习通过最小化所有局部数据集的并集的损失风险来训练共享模型,即:<br />![](https://cdn.nlark.com/yuque/__latex/e9e2bd202c638fba39a8d2a9238cd828.svg#card=math&code=%5Cmin%20_%7B%5Ctheta%7D%20f%28%5Cboldsymbol%7B%5Ctheta%7D%29%3A%3D%5Cfrac%7B1%7D%7Bn%7D%20%5Csum_%7Bi%3D1%7D%20f_%7Bi%7D%28%5Cboldsymbol%7B%5Ctheta%7D%29%20%5Ctext%20%7B%20with%20%7D%20f_%7Bi%7D%28%5Cboldsymbol%7B%5Ctheta%7D%29%3A%3D%5Cfrac%7B1%7D%7Bm%7D%20%5Csum_%7B%5Cxi%20%5Cin%20D_%7Bi%7D%7D%20l%28%5Cboldsymbol%7B%5Ctheta%7D%2C%20%5Cxi%29&height=48&id=h5OEx)<br />其中, ![](https://cdn.nlark.com/yuque/__latex/b99655e634670e45eaf2704b7bc12155.svg#card=math&code=D_i%3D%5Cleft%5C%7B%5Cxi_1%5Ei%2C%5Cldots%2C%5Cxi_m%5Ei%5Cright%5C%7D&height=23&id=dssik)为每个客户端拥有的数据集,![](https://cdn.nlark.com/yuque/__latex/2554a2bb846cffd697389e5dc8912759.svg#card=math&code=%5Ctheta&height=16&id=GAmSo)来自其边缘设备的m个数据点的集合。 ![](https://cdn.nlark.com/yuque/__latex/3bafad5c16ec41ebba0ffe1fcf6d8d12.svg#card=math&code=%5Ctheta%5Cin%5Cmathbb%7BR%7D%5Ed&height=19&id=UvSq5)为共享模型 。在边缘计算环境下,部署的边缘计算终端将会按照程序设定进行计算任务,但是,许多应用边缘计算的具体场景如工业物联网、智慧城市等,部分终端设备自身安全防护能力薄弱、漏洞问题严重,存在着窃听攻击、控制夺权攻击,这样部分终端就从训练“参与者”变成了“攻击者”。因此,我们引入的威胁模型为“诚实且好奇”的半可信威胁模型。我们假设攻击者可以是恶意攻击者,也可以是系统中“诚实且好奇”的客户端和租用的不安全服务器。存在着可能的客户端被迫攻击掉线与客户端共谋攻击。<br /> 为此,本作品研究增强算法在训练中的鲁棒性,以解决在不可信、不稳定网络环境下可能出现多用户互相共谋、用户在训练中突然掉线等问题;并研究设计一种轻量级安全多方计算协议来加强隐私保护,以解决联邦学习虽然通过只传输参数而不传输数据减小了隐私泄露的风险,但依然难以抵抗最新出现的模型提取、模型逆向等隐私攻击的问题。<br /> 在上述背景下,本作品拟实现以下目标:<br />写这个目标时,到底书写者是从场景中出发,还是从现有知识中,为什么直接提出来了这些目标<br />(1) **取得与原始联邦学习几乎相同的准确性。**<br />联邦学习采用分布式计算方法,该方法能够让多个客户端协作开发模型,而且不需要彼此共享敏感的数据。在多次迭代过程中,共享模型所覆盖的数据范围会比任何一个组织内部拥有的数据都要大得多,因此联邦学习开发的模型具有准确性高的特点。为了使我们的系统保持联邦学习的准确性,我们在上传参数加密时拟引入diffie-hellman协议的思想。相比于差分隐私技术在参数中添加大量的随机化,使得数据的可用性下降,该方法虽然同样在参数上添加随机数,但在服务器参数整合时所有随机数加和后归零。这使得本系统在准确性上堪比原始联邦学习的准确性。<br />为此,我们拟研究如何在所有用户在线和不同数量用户掉线等情况下都能使得整合参数时所有随机数加和归零。<br />(2)**减少联邦学习算法中的通信开销。**<br />研究在不影响算法精度的同时,如何设计一种优化策略来减小通信开销,以解决联邦学习由于在训练中需要多次上传和下载梯度参数,造成GB以上的通信开销,难以适用于带宽有限的边缘计算、物联网等场景的问题。<br />(3)**保护模型参数的隐私,同时不增加太多的计算开销。**<br /> 使用传统的同态加密算法难以避免同态加密本身存在的一些问题,受到训练过程缓慢、同态电路深度有限、客户端和云端通信等诸多限制。我们拟研究设计一种轻量级安全多方计算协议来加强隐私保护,并在此基础上引入周期性更新策略,使本系统能够在有效保护模型参数隐私的同时不增加通信开销。<br />(4)**抵抗用户训练掉线。**<br />研究增强算法在训练中的鲁棒性,以解决在不可信、不稳定网络环境下可能出现多用户互相共谋、用户在训练中突然掉线等问题。

2.2 场景问题分析

各个企业与团体在面对传统联邦学习时仍存在担心,数据孤岛问题成为了悬在中小型公司或团体在信息化、智能化时代发展的达摩克斯之剑,如何在小数据量、在隐私保护的前提下更好的利用在论文《Deep Leakage from Gradients》中,提到了一种关于梯度泄露攻击的算法,传统的、不加任何安全措施的泄露梯度是不安全的。如果我们不引入任何安全机制,通过多次梯度信息进行攻击恢复出参与训练客户端的样本特征的,不能将任何梯度信息暴露在第三方中(包括其他客户端与聚合服务器);
image.png
在系统设定中,每个客户端都是半诚实且好奇的,参预训练的客户端可能会为了寻求己方利益与其他客户端合谋,如果多个客户端对系统进行合谋攻击,某一方参与训练的客户端的信息很有可能就会泄露;

初始胡.png
图需要
在现实中,参预训练的训练参与方的网络环境并不是总是良好的,会存在网络带宽受限、网络截断、网络延迟等情况,客户掉线、通信资源有限的处理办法也是亟待解决的。

初始胡.png
😑😑😑图需要改
在边缘计算场景下,联邦学习的参与客户端的计算能力并不是类似于分布式机器学习是平均分配的;参与训练的客户端所拥有的数据量并不是平衡的;在非独立同分布数据中数据的类型分布也是不平衡的,在相同的训练参数设置下,数据量多、计算能力强的计算节点相较其它可能获得更好的训练结果,因此,如何设置统一的、合适的本地训练策略解决这些不平衡也是亟待解决的。
大体模型.png
😉😉😉😉😉😉图需要改,算力不平衡,数据不平衡

因此,怎样解决好面临的难题是本系统的核心重点。为达到系统初始设立的目标,本方案基于线性秘密共享算法为核心,签名认证算法为辅助,设计了使用了可信、可验证的安全多方计算框架的联邦学习方案。针对性的解决了在边缘计算环境下进行机器学习训练中面临的隐私保护问题。总体设计框架如下:
😀😀😀😀😀缺个图
对于边缘计算的机器学习训练场景,各个模块构建的基本联邦学习框架为机器学习训练寻求全局最优模型结果,对于训练、通信策略的设定提供了一种更好更快全局收敛与减少通信开销的解决办法,而对于基于线性秘密共享的安全多方计算框架则提供了较强的隐私保护功能:1.秘密共享算法提供了抵御共谋攻击、弹性应对用户掉线的可能;2.密码算法的应用提供了对信息机密性、完整性的保护。接下来将从安全多方计算方案、模块设计两层次介绍本方案。

2.3 安全多方计算方案设计

先导知识需要再精简一些,仔细一些

2.3.1 先导知识

2.3.1.1 秘密共享

秘密共享的思想是通过秘密分发算法将秘密分割,将分割得到的每个子份额分配给不同的参与者,只有合格的参与者集合才能通过密码重构算法恢复秘密,集合中存在非法参与者或参与者数量不达标都无法完成密码重构操作。该方案在大小至少为系统 - 图5(其中系统 - 图6是方案的安全性参数)的有限域F上进行参数化。例如,对于某些大的公共素数系统 - 图7系统 - 图8。 我们注意到需要这么大的字段,因为我们的方案要求客户保密共享他们的秘密密钥(其长度必须与安全性参数成比例才能通过安全性证明)。 我们还假设整数系统 - 图9(表示协议中的n个用户)可以用F中的不同域元素来标识。
给定这些参数,该方案由两种算法组成。 共享算法系统 - 图10将秘密系统 - 图11个用户ID的集合系统 - 图12和阈值系统 - 图13作为输入。它会产生一组份额系统 - 图14,每个份额都与一个不同的系统 - 图15相关联。重构算法系统 - 图16将阈值系统 - 图17和对应于系统 - 图18的子集系统 - 图19共享,并输出一个域元素系统 - 图20
正确性要求系统 - 图21并且系统 - 图22其中系统 - 图23,如果系统 - 图24系统 - 图25,则系统 - 图26。 安全性要求系统 - 图27和任意系统 - 图28使得系统 - 图29系统 - 图30其中”系统 - 图31“表示 两个分布是相同的。

2.3.1.2 认证加密

经过身份验证的加密结合了双方之间交换的消息的机密性和完整性保证。 它由三部分算法组成,分别是输出私钥的密钥生成算法,将密钥和消息作为输入并输出密文的加密算法系统 - 图32,以及将密文和密钥作为输入和原始明文或特殊错误符号⊥作为输出的的解密算法AE.dec。
为了正确起见,我们要求对于所有密钥 和所有消息x,AE.dec(c,AE.enc(c,x))=x。 为了安全起见,我们要求在[7]中定义的纯文本攻击和密文完整性下不可区分。对于任何在随机采样的密钥c下为其选择了消息加密的对手M(其中c对M未知),M不能区分两个不同消息在c下的新加密,M也不能创建新的 关于c的有效密文(不同于它收到的密文),其优点显而易见。

2.3.1.3 Diffie-Hellman密钥交换协议

密钥协议:
本作品密钥协议由一组密钥算法构成,分别为:1、KA.param(公共参数生成算法),该算法用于生成用户可共享的公共参数,即KA.param(k)=param,2、KA.gen(公-私钥对生成算法),通过此算法,用户可利用共享的公共参数生成一个公-私钥对Sk与Pk,即KA.gen(param)=(Sk,Pk),并选择性地同其它用户共享所拥有的公钥Pk,3、KA.agree(私有共享密钥生成算法),该算法允许用户X所持有的私钥Xsk与用户Y共享的公钥Ypk相结合,生成只在两个用户之间共享的密钥Sxy,即KA.agree(Xsk,Ypk)=Sxy。我们使用了迪菲赫尔曼密钥协议来实现这组算法,具体过程为:1、KA.param(k)=(G,g,q,h)函数从群G,生成器g和哈希函数H中随机抽样,2、KA.gen(G,g,q,h)=(t,g)从Zq中随机取样出t作为私钥Sk,以g作为公钥Pk,3、KA.agree(tx,g)=H((g))生成Sxy。在这个过程中,对于使用相同公共参数生成的不同公-私钥对,存在以下关系KA.agree(Xsk,Ypk)=KA.agree(Xpk,Ysk)。我们假设一个用户同时拥有X,Y用户共享公钥(Xpk,Ypk),出于安全性考虑,我们需要实现一种策略,使该用户无法从公钥中推出私有共享密钥Sxy,为了解决这一问题,可以使用决策迪菲赫尔曼假设(需不需要介绍?)。我们进一步假设这个用户可以获取KA.agree(Xsk,Rpk)和KA.agree(Ysk,Rpk)(R指除X,Y之外的任意用户),这一前提增大了Sxy泄露的风险,本作品使用甲骨文迪菲赫尔曼假设,设计出了更具安全性的策略,使得在这种情况下,该用户仍无法得到Sxy。(需要不要介绍?)

2.3.1.4 伪随机数发生器

我们需要一个安全的伪随机数发生器[9,54] PRG,它接收一定固定长度的均匀随机种子,其输出空间为[0,R)^m(即协议的输入空间)。 伪随机数发生器的安全性保证了只要种子对区分符是隐藏的,它在均匀随机种子上的输出就与输出空间的均匀采样元素在计算上是无法区分的。

2.3.1.5 签名方案

该协议依赖于标准UF-CMA安全签名方案(SIG.gen,SIG.sign,SIG.ver)。 密钥生成算法SIG.gen(k)\rightarrow\left(d{PK},d{SK}\right)以安全参数为输入,并输出私钥d{SK}和公钥d{PK}; 签名算法SIG.sign\left(d^{SK},m\right)\rightarrow\sigma作为密钥和消息的输入,并输出签名σ。 验证算法SIG.ver\left(d^{PK},m,\sigma\right)\rightarrow{0,1}将公钥,消息和签名作为输入,并返回指示该签名是否应被视为有效的位。 为了正确起见,我们要求\forall m,
Pr⁡dPK,dSK←&SIG.gen(k),σ←SIG.signdSK,m:&SIG.ver⁡dPK,m,σ=1=1
安全性要求,没有PPT对手,只要有一个诚实生成的新鲜公钥并可以访问在任意消息上生成签名的oracle,就应该能够在其上查询了oracle的消息上产生有效签名的可能性要小得多。不知道在说什么鬼

2.3.1.6 公钥基础结构

为了防止服务器模拟任意数量的客户端(在主动对手模型中),我们需要公钥基础结构的支持,该基础结构允许客户端注册身份并使用其身份对消息进行签名,以便其他客户可以验证此签名,但不能模拟它们。 在此模型中,每一方u都会在设置阶段将\left(u,d_u^{PK}\right)注册到公共公告板上。公告板将仅允许各方自行注册密钥,因此攻击方将无法冒充诚实方。

2.3.2 安全多方计算方案分析设计

其他密码算法应用应该在哪里说,叫安全多方计算方案设计合理吗

在本系统中,聚合服务器与训练客户端会进行部分通信,深刻剖析信息传输所面临的安全风险,必须使用相关信息传输方案设计确保信息安全。在信息传输中,我们面临着中间人攻击,信息的完整性、机密性需要密码算法进行保护,我们使用签名算法对传输信息的完整性、不可否认性进行保证,这样信息在传输过程中无法被非法篡改与否认;为了防止秘密共享算法中的秘密份额等机密信息泄露,我们使用Elgma公钥算法进行加密,确保秘密信息的机密性;在传输中,我们信息还可能会被第三方截取,并对信息进行重放,干扰联邦学习聚合过程,因而需要在传输的信息中加入时间戳等新鲜因子,以确保消息的新鲜性。针对可公开信息系统 - 图33和秘密信息系统 - 图34签名算法设计如下:
系统 - 图35
1.png
系统 - 图37
2.png
基于同态加密、差分隐私的训练手段都一定程度上牺牲了梯度的精度,为了减少加密带来的精度牺牲,我们可以

2.3.3 安全多方计算框架

本章不许要重构,但名字要起好,部分细节与图画需要写进去

总体上,本边缘联邦学习系统主要由可信第三方、聚合服务器、训练参与客户端三类参与方,可信第三方为聚合服务器、客户端提供身份公证服务;聚合服务器为参与客户端提供训练结果聚合服务及相应公钥广播服务;参与客户端进行本地训练。接下来将从不同层次阐明本系统安全方案设计。

可信第三方设立必要性分析

在联邦学习过程中,客户端与聚合服务器需要进行大量通信,为了保证信息完整性和真实性,需要对传输信息进行相关认证处理,保证消息在传递中未被未授权修改或者保证信息传递的完整一致性。所以,在系统设置中可信第三方的存在十分必要,可信第三方作为信任传递节点,可以进行证书签发、证书管理,将用户身份和用户公钥信息、公钥进行绑定,实现公钥与身份的唯一绑定关系。

客户端、聚合服务器联邦训练流程:

1、初始化阶段:

联邦学习前,需要系统设定一些必要参数与相关设置,确保满足系统传输与安全需要。

  • 由可信第三方TA生成客户端、聚合服务器公钥证书,进行身份公证,实现公钥与身份的唯一绑定关系,该公私钥对系统 - 图39主要用于系统进行相关身份认证、消息认证等。
  • 全网基准统一授时。
  • 规定一个有限域系统 - 图40用于秘密共享算法。
  • 由可信第三方TA生成一个安全素数系统 - 图41,该安全素数选取原则为:
    • 系统 - 图42必须具有大的素因子
    • 系统 - 图43的位数应该在1024位以上
    • 存在系统 - 图44系统 - 图45也为素数,同时满足系统 - 图46
  • 各个客户端与聚合服务器建立一个通信通路,在本实验中使用较为简单socket机制实现各个参与方的通信互联。
  • 设定秘密分享协议中的数值系统 - 图47和阈值系统 - 图48
  • 客户端协商使用何种训练模型。
  • 所有参与训练的客户端诚实的使用可信第三方TA提供的安全素数系统 - 图49生成DH密钥交换协议的相关参数系统 - 图50

    2、公钥共享:

    客户端:

  • 参与训练的客户端使用公钥生成算法生成两对公私钥对系统 - 图51,并且使用个人签名私钥对该公钥对进行签名系统 - 图52

  • 每一个客户端将生成的两对公钥系统 - 图53、签名系统 - 图54、时间戳,按照一定拼接方式拼合系统 - 图55,经base64编码后通过socket连接发送给聚合服务器。

聚合服务器:

  • 聚合服务器收集参与训练客户端发来的信息。
  • 聚合服务器在信息收集到一定时间后,且记录的训练客户端数目未达到系统 - 图56个及以上时,中断本次聚合训练;否则,聚合服务器广播训练客户端集合系统 - 图57和签名公钥信息系统 - 图58

    未命名绘图.png

3、秘密共享:

客户端:
  • 参与训练客户端收到聚合服务器广播的客户端集合系统 - 图60和签名公钥信息集合系统 - 图61,对签名公钥集合信息进行认证系统 - 图62,如若通过验证则进入下一步,如若未通过验证则中断训练。
  • 参与训练客户端使用随机数发生器系统 - 图63生成两个属于有限域 系统 - 图64 的随机数 系统 - 图65,系统 - 图66
  • 客户端在生成一个在有限域系统 - 图67上,次数为系统 - 图68的秘密多项式系统 - 图69,通过秘密共享算法生成秘密共享份额系统 - 图70
  • 客户端对于每一个其他客户端系统 - 图71,计算系统 - 图72
  • 客户端对密文系统 - 图73使用系统 - 图74杂凑算法取消息摘要并使用签名公钥系统 - 图75进行签名系统 - 图76
  • 客户端将消息系统 - 图77通过socket连接发送给聚合服务器。

聚合服务器:
  • 聚合服务器收集至少系统 - 图78条密文信息,记录参与训练客户端为集合系统 - 图79。否则中止本次训练。
  • 检查系统 - 图80与否,如若不成立,终止,否则,继续下一步.
  • 广播签名秘密信息系统 - 图81与客户端集合系统 - 图82

    4、掩码

    客户端:
  • 参与训练客户端从聚合服务器接收(并存储)密文列表系统 - 图83(并推断集合系统 - 图84)。如果列表大小小于t,则中止。

  • 客户端检验密文列表信息的完整性,并对密文签名进行认证,确保信息来源正确.
  • 客户端解密密文系统 - 图85,将其他客户端系统 - 图86发送的秘密份额系统 - 图87和公共掩码系统 - 图88进行保存。
  • 参与训练客户端使用系统 - 图89随机数生成器计算私有掩码向量系统 - 图90系统 - 图91,系统 - 图92.
  • 计算加入掩码的掩码梯度向量系统 - 图93
  • 对掩码进行压缩并使用签名公钥进行签名,系统 - 图94
  • 如果上述任何过程失败,则中止训练,否则,将系统 - 图95发送到服务器并移动到下一轮。

    聚合服务器:
  • 服务器接受至少t条签名掩码梯度向量信息,并将这些客户端记录为集合系统 - 图96,否则终止本次训练。

  • 检查系统 - 图97与否,如若不成立,终止,否则,继续下一步.
  • 广播签名掩码梯度向量信息系统 - 图98与客户端集合系统 - 图99

    5、一致性检验

    客户端
  • 客户端从服务器接收至少由系统 - 图100个用户(包括其自身)组成的列表系统 - 图101。如果系统 - 图102小于系统 - 图103,则中止。

  • 发送系统 - 图104到服务器。

    聚合服务器

    从至少系统 - 图105个用户处收集系统 - 图106(用系统 - 图107表示这组用户)。向系统 - 图108中的每个用户发送集合系统 - 图109

    6、解私人掩码

    客户端:
  • 从服务器接收集合系统 - 图110。验证系统 - 图111系统 - 图112,根据每个客户端的签名验证公钥进行验证系统 - 图113表示所有系统 - 图114(否则中止)。

  • 向聚合服务器发送自己所有的能够重构秘密的秘密份额份额系统 - 图115系统 - 图116

聚合服务器
  • 收集至少系统 - 图117个用户的响应,用系统 - 图118记录这些客户端。
  • 对于每个用户系统 - 图119,重构系统 - 图120,然后使用随机数生成器系统 - 图121重新计算掩码向量系统 - 图122
  • 计算并输出系统 - 图123系统 - 图124
  • 聚合服务器使用签名公钥将压缩后的向量签名系统 - 图125,将系统 - 图126进行广播。

    7、公共掩码求解

    客户端:
  • 参与训练的客户端收到聚合服务器发来的信息,并进行签名验证。

  • 客户端计算联邦聚合梯度系统 - 图127。判断训练精度与训练损失是否达到预期,如若没有,客户端做好进行下一次联邦学习准备;否则本次联邦学习目的达到,退出训练集群。

    2.4 系统模块设计

    2.4.1 本地数据管理模块

    数据分布描述的是数据的统计状态,根据数据样本的不同分布情况,可以将数据分为IID数据和Non-IID数据。IID(Independent Identically Distribution)数据指在数据集中,所有的数据样本都服从于同一分布,并且样本和样本之间相互独立。区别于IID数据,在Non-IID(Non - Independent Identically Distribution)数据集中,数据样本之间非独立,非同分布。
    传统的分布式机器学习利用IID数据进行模型训练,训练得到的模型是基于数据独立同分布的假设之上,然而在实际的边缘计算场景中,计算设备属于不同的个体、企业,由于不同客户端时间和空间等方面的差别,其数据的分布往往具有很大的差异,同时用户群体和地域的关联又使数据的分布存在一定的联系,此时数据不满足独立同分布,即为Non-IID数据,可以看出Non-IID数据是更符合实际边缘计算的应用场景的。
    在本边缘计算背景下非独立同分数据有如下特点:

  • 标签分布偏斜(先验概率偏移):即使系统 - 图128相同,边际分布系统 - 图129也会因客户端而异。例如当客户绑定到特定地理区域时,标签在客户之间的分布会有所不同:对于袋鼠只在澳大利亚或动物园;一个人的脸只在全球的少数地方出现;对于移动设备键盘,某些人群使用某些表情符号,而其他人群则不使用。

  • 数量偏斜或不平衡:处在不同环境下的客户端所采集到的数据集时远远不同的:布设在地铁站的人脸识别设备可以捕捉到许多人的人脸数据,而放在小区电梯中的人脸识别设备只能捕捉数量相对有限的人脸数据。

为了突出体现两种数据在联邦学习模型训练中的差异,在独立同分布数据集(IID)下,该联邦学习系统可以获得较传统分布式学习近似相同的准确性;在非独立同分布数据集(Non-IID)下,该使用安全多方计算框架的联邦学习可以获得较传统联邦学习近似的全局准确度。该在实验中我们使用了MNIST和CIFAR数据集,用两种划分手段将60000个数据样本分别划分为IID和Non-IID格式,并将这些样本分配给100个用户。在IID格式中,我们将数据打乱,按照模拟的客户端数量系统 - 图130随机分成系统 - 图131份独立同分布数据集(IID)供客户端进行训练。在Non-IID格式中,我们首先利用数字标签将数据排序,将其按照系统模拟客户端数量进行按标签划分。

  1. # noniid数据划分举例
  2. def mnist_noniid(dataset, num_users):
  3. print('minist noniid')
  4. """
  5. 对Mnist数据集进行noniid数据划分
  6. :param dataset:
  7. :param num_users:
  8. :return:
  9. """
  10. num_shards, num_imgs = 200, 300
  11. idx_shard = [i for i in range(num_shards)]
  12. dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
  13. idxs = np.arange(num_shards*num_imgs)
  14. labels = dataset.train_labels.numpy()
  15. # 标签排序
  16. idxs_labels = np.vstack((idxs, labels))
  17. idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]
  18. idxs = idxs_labels[0,:]
  19. # 划分和指派
  20. for i in range(num_users):
  21. rand_set = set(np.random.choice(idx_shard, 2, replace=False))
  22. idx_shard = list(set(idx_shard) - rand_set)
  23. for rand in rand_set:
  24. dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
  25. return dict_users

2.4.2 本地训练模块

本地训练模块三层介绍内容太过于空洞,训练算法?分布式梯度下降?周期性更新?怎末写

应用了联邦学习的边缘计算背景下的机器学习训练,可以获得较好的全局模型,但这并不代表本地训练算法设计、训练策略设计不重要。在本作品中,模拟面对不同的应用场景,提供了两种可供选择的神经网络:结构较为简单的MLP(Multilayer Perceptron) 和结构较为复杂的CNN(Convolutional Neural Network)。MLP神经网络也叫人工神经网络,它包含输入层、输出层和多个中间隐藏层,层与层之间为全连接,MLP的具体结构如图所示
image.png
在隐藏层中,每个神经元上包含对于一个输入的权值,一个偏置,以及一个激活函数。数据由输入层进入到隐藏层进行运算(乘上权值,加上偏执,激活函数运算一次)。每一个隐藏层计算得到的结果作为下一层的输入,直到进入输出层并得到输出结果。MLP神经网络构建代码如下:

  1. class MLP(nn.Module):
  2. def __init__(self, dim_in, dim_hidden, dim_out):
  3. super(MLP, self).__init__()
  4. self.layer_input = nn.Linear(dim_in, dim_hidden)
  5. self.relu = nn.ReLU()
  6. self.dropout = nn.Dropout()
  7. self.layer_hidden = nn.Linear(dim_hidden, dim_out)
  8. def forward(self, x):
  9. x = x.view(-1, x.shape[1]*x.shape[-2]*x.shape[-1])
  10. x = self.layer_input(x)
  11. x = self.dropout(x)
  12. x = self.relu(x)
  13. x = self.layer_hidden(x)
  14. return x

MLP适用于分类预测、回归预测等问题,在二维图像处理中,CNN比MLP更加有效,因此我们另提供了CNN神经网络,以满足不同用户的需求。CNN网络包括三个基本层:卷积层,池化层和全连接层,数据在前向传播的过程中,会在卷积层和池化层中进行多次卷积和池化处理并提取出特征向量,这些特征向量被输入到全连接层中,这个过程会一直重复,直到得到符合期望的结果。CNN神经网络结构如下图:
image.png

  1. <br />神经网络最左侧为卷积层,数据经过卷积计算进入池化层。池化分为三种类型,分别为最大池化(max pooling),平均池化(average pooling)和其他如L2池化等,在本作品使用的池化类型为最大池化,即取局部接受数据中值最大的点。<br />针对MNIST数据集和CIFAR数据集,我们搭建了两种CNN神经网络,神经网络架构具体代码如下:
  1. # Mnist CNN
  2. class CNNMnist(nn.Module):
  3. def __init__(self, args):
  4. super(CNNMnist, self).__init__()
  5. self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
  6. self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
  7. self.conv2_drop = nn.Dropout2d()
  8. self.fc1 = nn.Linear(320, 50)
  9. self.fc2 = nn.Linear(50, args.num_classes)
  10. def forward(self, x):
  11. x = F.relu(F.max_pool2d(self.conv1(x), 2))
  12. x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
  13. x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])
  14. x = F.relu(self.fc1(x))
  15. x = F.dropout(x, training=self.training)
  16. x = self.fc2(x)
  17. return x
  18. # Cifar CNN
  19. class CNNCifar(nn.Module):
  20. def __init__(self, args):
  21. super(CNNCifar, self).__init__()
  22. self.conv1 = nn.Conv2d(3, 6, 5)
  23. self.pool = nn.MaxPool2d(2, 2)
  24. self.conv2 = nn.Conv2d(6, 16, 5)
  25. self.fc1 = nn.Linear(16 * 5 * 5, 120)
  26. self.fc2 = nn.Linear(120, 84)
  27. self.fc3 = nn.Linear(84, args.num_classes)
  28. def forward(self, x):
  29. x = self.pool(F.relu(self.conv1(x)))
  30. x = self.pool(F.relu(self.conv2(x)))
  31. x = x.view(-1, 16 * 5 * 5)
  32. x = F.relu(self.fc1(x))
  33. x = F.relu(self.fc2(x))
  34. x = self.fc3(x)
  35. return x

在现实中,拥有较多数据的计算节点可能在相同的时间内能训练出更好的模型,因此,数据量较少的计算节点训练的模型梯度可能会拖慢全局梯度收敛速度。针对这种现象,我们设定了较长的周期性更新策略,让每个计算节点发送给聚合服务器的本地梯度能够充分收敛,这一点的改变使得让全局梯度快速聚合,同时节约了通信开销。

2.4.3 通信模块

在边缘计算场景下,本地训练结果需要进行可靠的通信方法传回聚合服务器,我们选取了socket机制进行通信。Socket(套接字)可以看成是两个网络应用程序进行通信时,各自通信连接中的端点,这是一个逻辑上的概念。它是网络环境中进程间通信的API(应用程序编程接口),也是可以被命名和寻址的通信端点,使用中的每一个套接字都有其类型和一个与之相连进程。通信时其中一个网络应用程序将要传输的一段信息写入它所在主机的 Socket中,该 Socket通过与网络接口卡(NIC)相连的传输介质将这段信息送到另外一台主机的 Socket中,使对方能够接收到这段信息。在实现中,我们使用socketIO_client库中的类实现通信,使用其中的emit方法实现各个阶段事件的触发,一轮联邦训练的总体通信过程为:
未命名绘图.png

2.4.4 全局聚合模块

放在模块介绍最前面是不是更好,借这个讲明使用总体联邦学习算法

为了在边缘计算背景下进行的机器学习训练得到全局最优的模型,我们采用FederatedAvering梯度平均,在谷歌论文《Communication-Efficient Learning of Deep Networks from Decentralized Data》中提出了该算法,主要分为梯度下降
worker端(重复以下步骤)
step1: 各个节点从server端得到最新的参数系统 - 图135;
step2: 对于各个worker, 利用本地的数据和参数系统 - 图136,计算当前节点上样本的梯度系统 - 图137;
step3: 向server发送该梯度系统 - 图138;
server端(重复以下步骤)
step1: 从worker上获得系统 - 图139
step2: 更新参数系统 - 图140
应用在边缘计算场景下,FederatedAvering算法将从云端得到的参数下发给每个边缘计算终端设备,终端设备利用本地数据和参数在网络边缘侧计算出梯度,并将梯度值反馈给云端。
各个终端设备梯度汇总到云端后进行加权平均,得到最终的参数,也就达到了模型聚合的效果。

  1. import copy
  2. import torch
  3. def FedAvg(w):
  4. w_avg = copy.deepcopy(w[0])
  5. for k in w_avg.keys():
  6. for i in range(1, len(w)):
  7. w_avg[k] += w[i][k]
  8. w_avg[k] = torch.div(w_avg[k], len(w))
  9. return w_avg

2.5 方案分析