【发布时间】:2020-02-10 15:34:03
【问题描述】:
我用我自己的数据集中的 TFF 编写了一个代码,所有代码都可以正常运行,除了 这一行
在 train_data 中,我制作了 4 个数据集,加载了 tf.data.Dataset,它们的类型为“DatasetV1Adapter”
def client_data(n):
ds = source.create_tf_dataset_for_client(source.client_ids[n])
return ds.repeat(10).map(map_fn).shuffle(500).batch(20)
federated_train_data = [client_data(n) for n in range(4)]
batch = tf.nest.map_structure(lambda x: x.numpy(), iter(train_data[0]).next())
def model_fn():
model = tf.keras.models.Sequential([
.........
return tff.learning.from_compiled_keras_model(model, batch)
所有这些都运行正常,我得到了教练和状态:
trainer = tff.learning.build_federated_averaging_process(model_fn)
除了,当我想用这段代码开始训练和循环时:
state, metrics = iterative_process.next(state, federated_train_data)
print('round 1, metrics={}'.format(metrics))
我不能。错误来了!那么,错误可能来自哪里?从数据集的类型?还是我让数据联合的方式?
【问题讨论】:
-
能否将问题扩展到包含所看到的确切错误消息?从上面的代码来看,变量名称似乎不同:创建了一个
train_data,然后请求了一个federated_train_data。 -
感谢您的评论,我纠正了注意力不集中的错误。但我的问题是,在执行最后一行(第一轮)时,内核需要很长时间(运行),但随后它崩溃而没有显示任何内容。
-
听起来您的机器可能内存不足?很难说没有更多信息,例如输入数据集是什么样的,代码是在 CPU 上运行还是在 GPU 上运行?尝试减少内存占用的几件事:仅在单个客户端上运行进行测试(在构建
federated_train_data时将range(4)更改为range(1)),减少.shuffle()缓冲区的大小。在数据输入管道上在.map()之前调用.batch(),并使用num_parallel_calls参数,可以大大加快数据读取速度。 tensorflow.org/guide/data_performance 是一个很好的向导。 -
@ZacharyGarrett 请看答案,我做代码
标签: tensorflow2.0 tensorflow-federated