6.3 推荐模型离线训练之 Parameter Server
背景:分布式可扩展的 Parameter Server 方案的提出,几乎完美地解决了机器学习模型的分布式训练问题
分布式训练问题(Spark 的并行梯度下降过程相对低效):
- 采用同步阻断式的梯度下降策略
- 每次迭代都需要 mater 节点将模型权重参数通过全局广播的形式发送到各 worker 节点
- 单 master 节点导致宽带瓶颈和内存瓶颈
- 同步地广播发送所有权重参数,使系统的整体网络负载非常大
Parameter Sever 的解决方案:
- 用异步非阻断式的分布式梯度下降策略替代同步阻断式的梯度下降策略
- 多 server 节点的协同:
- 实现多 server 节点的架构,避免单 master 节点带来的带宽瓶颈和内存瓶颈
- 使用一致性哈希、参数范围拉取、参数范围推送等工程手段实现信息的最小传递,避免广播操作带来的全局性网络阻塞和带宽浪费
6.3.1 Parameter Server 的分布式训练原理
假设一个机器学习问题的损失函数为:
- n 是样本总数, 是样本 i 的损失函数
- 是正则化项
为了求解 ,使用梯度下降法
Parameter Server 的主要作用:并行进行梯度下降的计算,完成模型参数的更新直至最终收敛
- server 节点:保存模型参数、接受 worker 节点计算出的局部梯度、汇总计算全局梯度,并更新模型参数
- worker 节点:保存部分训练数据,从 server 节点拉取最新的模型参数,根据训练数据计算局部梯度,上传给 server 节点
Parameter Server 的分布式训练流程:
- push 操作:worker 节点利用本节点上的训练数据,计算好局部梯度,上传给 server 节点
- pull 操作:为了进行下一轮的梯度计算,worker 节点从 server 节点拉取最新的模型参数到本地
- 每个 worker 载入一部分训练数据
- worker 节点从 server 节点拉取(pull)最新的相关模型参数
- worker 节点利用本节点数据计算局部梯度
- worker 节点将局部梯度推送(push)到 server 节点
- server 节点汇总局部梯度,计算全局梯度,更新模型
- 跳转到步骤 2,直到迭代次数达到上限 or 模型收敛
详见书 p192 图 6-10 & p190 代码 6-1
6.3.2 一致性与并行效率之间的取舍
两种分布式梯度下降策略:
- 同步阻断式的分布式梯度下降策略:
- 优点:是一致性最强的分布式梯度下降方法,其计算结果与串行梯度下降的的计算结果严格一致
- 缺点:需要所有 worker 节点都计算好局部梯度,push 给 server 节点,由 master( server)节点汇总梯度,计算好新的模型参数后,才能开始下一轮的梯度计算。因此,最慢的 worker 节点会阻断其他所有节点的梯度更新过程
- 异步非阻断式的分布式梯度下降策略:
- 优点:加快了训练速度:其他 worker 节点计算梯度的进度不会影响本节点的梯度计算。所有节点都在并行工作,不会被其他节点阻断。
- 如书 p 193 图 6-11 所示,第 10 轮迭代后的 push&pull 过程并没有结束,但第 11 轮迭代就已经开始了,这时候各个 workers 并没有拉取最新的模型权重参数,仍在使用第 10 轮的权重参数计算第 11 轮的梯度
- 缺点:导致了模型一致性的损失,并行训练的结果与原来的单点串行训练的结果是不一致的,会对模型收敛的速度造成一定影响
- 优点:加快了训练速度:其他 worker 节点计算梯度的进度不会影响本节点的梯度计算。所有节点都在并行工作,不会被其他节点阻断。
同步和异步的折中:可以通过设置“最大延迟”等参数限制异步计算的程度,eg. 限定在三轮迭代之内,模型参数必须更新一次,否则 worker 节点需要停下等待 pull 操作的完成
Parameter Server 用异步非阻断式的分布式梯度下降策略替代同步阻断式的梯度下降策略,大幅加快了训练速度
6.3.3 多 server 节点的协同和效率问题
Parameter Server 解决单点 master 效率低下问题的方法:采用了服务器节点组内多 server 的架构,每个 server 主要负责部分模型参数(模型参数使用 key-val 的形式,每个 server 负责一个参数键范围内的参数更新即可)
- 新的问题:
- 每个 server如何决定自己负责哪部分参数范围?
- 如果有新的 server 节点加入,如何在保证已有参数范围不发生大的变化的情况下加入新的节点?
- 解决办法:使用一致性哈希管理参数
- 参数范围拉取:当某个 worker 节点希望拉取新的模型参数时,该节点将发送不同的范围拉取请求到不同的 server 节点,之后各 server 节点可以并行地发送自己负责的权重参数到该 worker 节点
- 参数范围推送:某 worker 节点计算好自己的梯度后,只需要利用范围推送操作把梯度发送给一部分相关的 server 节点