联邦学习(Federated Learning)
一、应用场景
从以上两个场景可以看出当前机器学习领域对于知识共享、隐私保护两大需求所面临的挑战:数据量越大、越全面,训练出的模型效果越佳,但单个用户设备数据量小,训练出来的模型效果不佳,故服务器需集合大量用户数据来训练模型;服务器要集合用户数据,需每个设备进行数据共享,服务器需和用户设备频繁通信,这样易造成数据泄露,想要保护隐私,本地数据不能离开设备。
要意识到这样一个现状:欧美国家的法律对用户隐私保护非常严格,但我国的个人信息获取渠道泛滥,网站/app均需获取个人信息,但人们不以为然。因此,面对以上两个场景,更智能or更隐私之间怎么选择呢?能否两者兼得呢?这时联邦学习来了,他的学习目的便是训练出好模型的同时能保护隐私。在介绍联邦学习前,首先得了解下分布式机器学习,因为联邦学习的工作思路就是基于分布式机器学习拓展开来的。
二、分布式机器学习
分布式机器学习是搭建在多设备、多节点的分布式集群上的,底层支持CPU、GPU等多种设备,并提供自适应的任务调度能力,合理地利用计算资源,来完成深度学习模型的训练并获取好的收敛,以达到训练速度与精度的平衡。分布式机器学习涉及到训练数据的分布式存储、计算任务的分布式操作、模型结果的分布式分布等多个方面。传统的分布式机器学习在服务器上训练模型,客户端对服务器发起请求,进行预测,因此,所有模型都会保存在服务器上。
分布式机器学习的主要思路是将每台用户设备当作一个worker节点,每个worker各自计算出梯度,各自把梯度发给server,server由此可学出来一个模型,且数据没有离开worker,控制权还在用户手上,server看不到用户数据,没有违反隐私限制。要注意,worker与server间的通信复杂度就是模型参数的数量。
如上图,可概括分布式机器学习的迭代流程如下:
step1:worker向server索要参数
step2:server将最新的参数发给worker
step3:worker使用本地的数据进行本地计算,算出梯度(计算量大)
step4:worker将梯度发给server
step5:server用梯度来更新模型参数(计算量很小)
三、联邦学习
联邦学习是谷歌2016年提出的一种分布式机器学习框架,其主要思路是建立基于多个设备上数据集的机器学习模型,同时防止数据泄露。联邦学习是一个学习过程,在这个过程中,各个数据所有者协作地训练一个模型m;在这个过程中,任何数据所有者都不向其他人公开自身的数据,同时也使训练出来的模型的精度、性能接近传统的模型m。传统的模型训练方法是将所有数据放在一起,训练一个模型,而联邦学习无需将数据集中在服务器中,这样保护了隐私。下图是联邦学习的基本工作机制。
联邦是由很多州组成,每个州拥有高度自治权,组成的联邦是个比较松散的政府,联邦学习就是将用户设备比作州,将中心服务器比作联邦政府,每个用户设备都有高度自治权,自己掌控着自身的数据信息。联邦学习是一种特殊的分布式机器学习,两者没有什么本质区别,主要是应用场景的不同,分布式机器学习主要是应用在机房服务器设备上,而联邦学习主要是应用在手机、平板这样的移动端。联邦学习的应用很新,但方法上相对于分布式机器学习而言没有太多改变。下面是联邦学习和分布式机器学习主要的区别:
(1)用户对自己的设备和数据有绝对的控制权
用户可以随时让自己的设备(比如手机)停止参与计算和通信。诸如mapreduce等传统的分布式系统中,worker受server的控制,接受server的指令。
(2)参与联邦学习的worker节点不稳定
联邦学习worker节点大多都是手机、平板、智能家居等设备,这些设备不稳定,计算性能差别也大(比如iphone11和iphone5同时开始计算);传统分布式机器学习的worker大多都是机房中连着高速宽带24小时开机的机器,有专人维护,非常稳定,计算性能也几乎是一样的。
(3)联邦学习的通信代价非常大
联邦学习的通信代价远大于计算的代价,因为手机等移动端都基本都是远程连接服务器(甚至设备和服务器不在一个国家),所以带宽低,网络延迟高。而分布式机器学习基本都是网线连接着高速宽带进行通信的。
(4)联邦学习的数据并非独立同分布
数据并发,每个worker都有一部分数据,由于用户使用习惯不一样(比如拍风景和自拍),每个手机用户的数据统计性质是不一样的。由于数据不是独立同分布,很多已有的减少通信次数的算法不适用了。而传统的分布式机器学习中,数据的划分是均匀的,随机打乱的。
(5)联邦学习的节点负载不平衡
节点上的数据有些大有些小(比如有些人一天拍几十张照片,有些人几十天拍一张照片),建模和计算容易出问题:若给每个用户不同的权重,所建模型对拍照多的用户有用,而拍照少的用户几乎被忽略掉了;若给每个用户相同的权重,这样学出来的模型对拍照多的用户不太好。传统分布式计算都会做负载平衡,但联邦学习无法做负载平衡(不能把一个用户的数据转移到另一个用户的设备上)。
注:其中(2)至(5)即为联邦学习技术上的难点。
四、联邦学习的实现
1.算法实现
在介绍联邦学习的算法前需要明确一个点,也就是第三章中联邦学习与分布式学习的第(3)点区别:联邦学习的通信代价非常大。选择联邦学习的实现算法,要以减少通信次数为最主要的方向,哪怕不惜加大计算量,毕竟计算的代价远远低于通信的代价(这也会在第七章说明)。整个联邦学习的迭代考虑两个算法:并行梯度下降和Federated Averaging。
(1)并行梯度下降
如上图所示,w是参数,gi是梯度,m为节点个数,α是学习率(步长),使用梯度下降更新参数w,具体迭代流程大致可归纳为以下几步:
step1:server把最新的参数w发给worker
step2:worker计算梯度,使用加密技术屏蔽梯度选择,把屏蔽后的梯度传回server
step3:server执行安全聚合,再更新参数,不需要了解任何分区信息
step4:迭代,直至算法收敛
step5:server将汇总的结果(加密)返回给worker
step6:worker使用解密的梯度更新自身的模型参数,得到最终模型
(2)Federated Averaging算法
如上图所示,此算法能用更少的通信次数就能达到收敛,这样做的好处是可在两次通信之间可以把参数做很大的改进,而不仅仅是一次梯度下降。具体迭代流程大致可归纳为以下几步:
step1:server把最新的参数w发给worker
step2:worker使用参数w和本地数据去计算梯度g,在本地做梯度下降
step3:本地更新参数,重复这个步骤几个epoch(1~5个)
step4:把最终本地得到的参数wi发给server
step5:server将所有worker发过来的参数做平均,最后平均出来的参数作为新参数。
这里有一个Federated Averaging算法应用的实例:我们有一群客户端,他们有温度传感器,我们这里想计算客户端中的温度有多少比例超过阈值,这种计算中有两个输入——一个是客户端中的温度读数,另一个是服务器中的阈值,具体步骤如下:
step1:服务器向所有客户端广播阈值
step2:客户端接收到阈值后,通过自身温度读数和阈值进行tensorflow计算,超过阈值为1,不超过为0
step3:拥有很多1和0后,通过Federated Averaging算法,执行分布式聚合计算,求这些1和0的平均值,并将结果发给服务器。
(3)算法对比
如上图所示。横坐标为通信次数,纵坐标为损失函数,对比两个算法可发现,在相同通信次数的情况下,FedAvg算法收敛更快。因此在达到相同收敛下,FedAvg所需的通信次数更少,达到了降低通信次数的目的。FedAvg以更大的计算量为代价,换取更少的通信次数。
如上图所示,换个角度对比,横坐标为epochs,纵坐标为损失函数,可见在经历相同epochs计算的情况下,梯度下降收敛更快。因此要达到收敛,FedAvg所需的计算量大于梯度下降。
由于联邦学习中的计算代价小,通信代价大,因此FedAvg是更优的算法。
2.隐私保护
联邦学习可以看作是一种保护隐私的分散协作机器学习。隐私保护是联邦学习的重点之一,目前所用到的隐私保护方法主要有以下4种:
a)安全多方计算:提供安全证明,加密通信过程
b)差分隐私法:即向参数中添加一些随机噪声
c)同态加密:参数交换
d)区块链联邦学习架构:设备的本地模型更新通过区块链进行交换和验证
五、联邦学习的分类
构建一个联邦学习矩阵,矩阵的每一行表示一个样本,每一列表示一个特征。构成该矩阵的三要素分别为:样本空间I(用户),特征空间X(业务、属性),标签空间Y(指标,例如学位)。依据联邦学习矩阵结构及应用场景的特点,大致可将联邦学习分为三个类别:水平联邦学习、垂直联邦学习、联邦迁移学习。
(a)水平联邦学习
水平联邦学习适用于数据集共享的特征空间相同但在样本空间不同的场景。例如两个区域(如长沙和拉萨)的银行有来自各自区域不同的用户,并且他们的用户交集非常小,但他们的业务非常相似(如相似的信贷业务、理财产品),所以他们的特征空间是相同的。用户独立训练,只共享参数更新的子集。图(a)是水平联邦学习的矩阵结构。
(b)垂直联邦学习
垂直联邦学习适用于两个数据集共享的样本空间相同但特征空间不同的场景。例如,同一个城市的两个不同的公司,一个是银行,一个是电商公司,处在同一个城市的他们用户交集很大,故可看作有相同的样本空间。银行记录了收支行为和信用评级,电商公司保留了用户的浏览及购买历史,故他们的特征空间有很大的不同。将这些不同的特征聚集起来,以隐私保护的方式计算损失函数和梯度,从而用双方的数据协作构建一个模型。图(b)是垂直联邦学习的矩阵结构。
(c)联邦迁移学习
联邦迁移学习适用于两个数据集在样本空间和特征空间上均不同的场景。例如,一家是位于中国的银行,另一家是位于美国的电商公司,一方面,他们的用户群只有一个小的交集,另一方面,双方的特征空间只有很小部分重叠。面对这种情况,需运用联邦迁移学习——利用有限的公共样本集来学习两个特征空间之间的公共表示,然后用它来获得只有单边特征的样本预测。图(c)是联邦迁移学习的矩阵结构。
六、联邦学习应用领域
1.智能零售
目的是利用机器学习技术为客户提供个性化服务。智能零售业务涉及的数据特征主要包括用户购买力、用户个人偏好和产品特征,这三个特征可能分散在三个不同的部门或企业中,例如一个用户的购买力可从银行存款中推断,她的个人偏好可从她的社交网络中分析出来,而产品的特征则通过电子商店记录下来。这种情况面临两个问题:为了保护数据隐私和数据安全,银行、社交网络和电商网站之间的数据壁垒难以打破,故不能直接聚合他们的数据在一起进行训练;三方存储的数据通常是异构的,传统的机器学习模型不能直接处理异构数据。
联邦学习和迁移学习是解决这两个问题的一个好方法。联邦学习能够让各自数据不离开企业的前提下建立一个好的机器学习模型,这样保护了数据隐私。此外,迁移学习能够解决数据异构问题。因此联邦学习为我们构建跨企业、跨数据、跨领域的大数据和人工智能生态圈提供了良好的技术支持。
2.智能金融
一些用户从一家银行借款来还另一家银行的贷款,要找到这样的用户,即要找到两家银行的样本交集,但银行间的用户列表数据肯定是不能共享的,这时可通过联邦学习框架的加密机制,对各方的用户列表进行加密,然后通过诸如政府机构这样的第三方,聚合加密列表取其交集,最后通过解密得到需要的交集用户。
3.智能医疗
比如各大医院面对某种病毒的病情,需要尽可能多的相关数据(症状、基因序列、医疗报告),从而训练出更好的模型,但很多私密医疗数据难以收集,也不宜共享。联邦学习解决了这个问题,中心服务器可以通过收集各个医院的加密梯度来更新模型参数。
七、研究方向
1、降低通信次数
联邦学习最重要的研究方向,哪怕计算量会大很多,但减少通信次数是值得的。降低通信次数的算法理念基本可概括为“多做计算,少做通信”:获得相比于梯度更好的下降方向,迭代次数也远少于梯度的迭代,迭代次数少了则通信次数也少了。目前,降低联邦学习通信次数最常用的算法便是Federated Averaging算法(详情可参考第四章)。
2.隐私保护
虽然整个联邦学习流程中数据没有离开过worker,但并不是绝对的隐私保护。用户的数据发生了间接泄露:随机梯度的计算就是把用户本地的数据做了个函数变换,把数据映射到了梯度。
如上图所示,l为损失函数,w为模型参数,xi为数据,yi为标签,gi为梯度(向量),可见随机梯度就是把数据xi做了个放缩,等于梯度就是把数据做了个变换,梯度几乎携带了数据所有的信息,另外也可以通过模型参数反推出数据。
如上图所示,对于梯度,使用适当的分类器,是可以反向推断出用户属性的。随机矩阵变换能在一定程度上抵制数据的反推,是个值得研究的方法。
如上图所示,联邦学习以往的方法是加噪声,但噪声加少了隐私保护效果不好,加多了模型学习的不好。因此,在保证模型效果的情况下,联邦学习隐私泄露比较容易,而目前没有特别好的方法,这是个研究方向。
3.鲁棒性
如上图所示。拜占庭将军问题是针对分布式系统中出现异常节点的情况,这个异常节点会给其他节点发送错误的信息,故需解决拜占庭将军问题和恶意攻击。目前大部分防御方法都是假设数据是独立同分布的,所以用在联邦学习上效果不怎么好,因为联邦学习应用场景的数据不是独立同分布的。目前还没有特别有效的防御,故其鲁棒性是个研究方向。
八、总结
如上图所示,联邦学习的特点可以归纳为以下三点:
1)联邦学习的目标是让多个用户一起协同训练出一个模型,但这些用户不共享数据,用户数据不离开本地,不会去到第三方。
2)联邦学习是一种特殊的分布式机器学习,其主要思路是使用梯度进行交流而不是数据本身。
3)联邦学习的比普通的分布式机器学习困难,因为联邦学习的数据不是独立同分布的,同时还面临着一些其他挑战,比如节点的负载不均衡等。
联邦学习主要有三个研究方向,如下图。
从企业层面上来看,联邦学习可以在保护本地数据的同时,为多个企业建立统一的模型,使企业以数据安全为前提,共同取胜。相信在不久的将来,联邦学习能打破行业间的壁垒,建立一个数据和知识可以安全共享的社区。
参考文献
[1]《Federated Machine Learning: Concept and Applications》
[2]https://www.bilibili.com/video/av82029976?from=search&seid=12742299253974066406
[3]https://www.bilibili.com/video/av54176168?from=search&seid=676454166257594864
注:部分文字、图片来自网络,如涉及侵权,请及时与我们联系,我们会在第一时间删除或处理侵权内容,电话:4006770986。