【问题标题】:Learning rate setting when calling the function tff.learning.build_federated_averaging_process调用函数 tff.learning.build_federated_averaging_process 时的学习率设置
【发布时间】:2020-08-03 12:39:40
【问题描述】:

我正在执行一个联邦学习过程,并使用函数 tff.learning.build_federated_averaging_process 来创建一个联邦学习的迭代过程。正如 TFF 教程中提到的,这个函数有两个参数,分别称为 client_optimizer_fn 和 server_optimizer_fn,在我看来,它们分别代表客户端和服务器的优化器。但是在 FedAvg 论文中,似乎只有客户端进行优化,而服务器只进行平均操作,那么 server_optimizer_fn 到底是做什么的,它的学习率是什么意思?

【问题讨论】:

    标签: tensorflow-federated


    【解决方案1】:

    McMahan et al., 2017 中,客户端将本地训练后的模型权重传送到服务器,然后将其平均并重新广播到下一轮。不需要服务器优化器,平均步骤会更新全局/服务器模型。

    tff.learning.build_federated_averaging_process 采用了稍微不同的方法:客户端收到的模型权重delta,本地训练后的模型权重被发送回服务器。这个 delta 可以看作是一个伪梯度,允许服务器使用标准优化技术将其应用于全局模型。 Reddi et al., 2020 深入研究了这个公式以及服务器上的自适应优化器(Adagrad、Adam、Yogi)如何大大提高收敛速度。使用不带动量的 SGD 作为服务器优化器,学习率为1.0,完全恢复了 McMahan et al., 2017 中描述的方法。

    【讨论】:

      【解决方案2】:

      感谢您的回答扎卡里。在 McMahan et al., 2017 中,引入了两种实现联邦学习的方法,一种是计算客户端的平均梯度并将其发送到服务器进行平均操作,另一种是将平均梯度应用于每个客户端模型并发送客户端模型到服务器进行平均操作。 McMahan et al., 2017 的算法 1 使用第二种方式来实现联邦学习,而 TFF 根据您的回复使用第一种方式。让我困惑的是,在我看来,无论 TFF 使用哪种方式,都应该只有一个学习率,也就是说,第一种方式应该只有服务器 lr,没有客户端 lr,第二种方式应该只有客户端 lr 和没有服务器 lr。正如 McMahan et al., 2017 中提到的,只有一个符号 Eta 来表示 lr,没有 Eta_client 或 Eta_server。

      【讨论】:

      • 也许 McMahan 中提到的两种方法是 FedAvg 算法和 FedSGD 算法? FedSGD 在不更新客户端模型的情况下计算梯度,而 FedAvg 在发回新模型(或模型增量)之前在本地执行许多 SGD 步骤(更新客户端模型)。你说得对,前者只有一个服务器学习率。在论文中,后者有效地具有1.0的服务器学习率;不缩放客户端更新(也是 tff.learning.build_federated_averaging_process 的默认值)。
      最近更新 更多