【问题标题】:Do we need dataset in each of the worker when using ParameterServerStrategy?使用 ParameterServerStrategy 时,我们是否需要每个 worker 中的数据集?
【发布时间】:2026-01-06 22:40:01
【问题描述】:

在来自tensorflow API的ParameterServerTraining的教程代码中,在model.fit部分有如下代码sn-p

def dataset_fn(input_context):
  global_batch_size = 64
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)

  x = tf.random.uniform((10, 10))
  y = tf.random.uniform((10,))

  dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(10).repeat()
  dataset = dataset.shard(
      input_context.num_input_pipelines,
      input_context.input_pipeline_id)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(2)

  return dataset

dc = tf.keras.utils.experimental.DatasetCreator(dataset_fn)

也有人说

The code in dataset_fn will be invoked on the input device, which is usually the CPU, on each of the worker machines.

这是否意味着数据集必须在每个工作服务器的相同存储上(比如参数服务器和工作服务器是不同的机器)?

或者有没有什么方法可以让一台机器上的参数服务器将训练数据发送给工人,而不需要工人机器直接将数据集存储在我不明白的 ParameterServerStrategy 中?

【问题讨论】:

  • (如果有人有同样的疑问)经过进一步研究,我发现,我们可以在存在参数服务器的 1 个服务器上启动协调器,我们可以使用 tf.distribute 启动工作程序和参数 ps。 Server(),它接受来自协调器的减少调用或训练调用。检查此链接tensorflow.org/api_docs/python/tf/distribute/Server

标签: python tensorflow distributed-computing distributed mlops


【解决方案1】:

为了社区的利益在这里回答。

来自评论部分:

(如果有人有同样的疑问)经过进一步研究,我发现,我们可以 在存在参数服务器的 1 台服务器上启动协调器,并且 我们可以使用启动workers和参数ps tf.distribute.Server(),接受reduce调用或训练调用 从协调员。检查这个链接 tensorflow.org/api_docs/python/tf/distribute/Server

【讨论】:

    最近更新 更多