【发布时间】:2026-01-06 22:40:01
【问题描述】:
在来自tensorflow API的ParameterServerTraining的教程代码中,在model.fit部分有如下代码sn-p
def dataset_fn(input_context):
global_batch_size = 64
batch_size = input_context.get_per_replica_batch_size(global_batch_size)
x = tf.random.uniform((10, 10))
y = tf.random.uniform((10,))
dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(10).repeat()
dataset = dataset.shard(
input_context.num_input_pipelines,
input_context.input_pipeline_id)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(2)
return dataset
dc = tf.keras.utils.experimental.DatasetCreator(dataset_fn)
也有人说
The code in dataset_fn will be invoked on the input device, which is usually the CPU, on each of the worker machines.
这是否意味着数据集必须在每个工作服务器的相同存储上(比如参数服务器和工作服务器是不同的机器)?
或者有没有什么方法可以让一台机器上的参数服务器将训练数据发送给工人,而不需要工人机器直接将数据集存储在我不明白的 ParameterServerStrategy 中?
【问题讨论】:
-
(如果有人有同样的疑问)经过进一步研究,我发现,我们可以在存在参数服务器的 1 个服务器上启动协调器,我们可以使用 tf.distribute 启动工作程序和参数 ps。 Server(),它接受来自协调器的减少调用或训练调用。检查此链接tensorflow.org/api_docs/python/tf/distribute/Server
标签: python tensorflow distributed-computing distributed mlops