【发布时间】:2020-04-01 16:15:08
【问题描述】:
我正在 Tensorflow 上构建一个用于图像去模糊的 GAN,它是 DeblurGANv2 的一个实现。我将 GAN 设置为它有两个输入、一批模糊图像和一批清晰图像。按照这一行,我将输入设计为带有两个键['sharp', 'blur'] 的 Python 字典,每个键都有一个形状为 [batch_size, 512, 512, 3] 的张量,这样可以轻松地将模糊图像批量输入生成器,然后输入输出生成器和清晰的图像批处理到鉴别器。
根据最后的要求,我创建了一个tf.data.Dataset,它准确地输出了一个包含两个张量的字典,每个张量都有它们的批次维度。这与我的 GAN 实现完美互补,一切正常且顺利。
所以请记住,我的输入不是张量,而是一个没有批量维度的python dict,这将与稍后解释我的问题有关。
最近,我决定使用 TensorFlow 分布策略添加对分布式训练的支持。 Tensorflow 的这一特性允许将训练分布在多个设备上,包括在多台机器上。一些实现有一个特性,例如MirroredStrategy,它接受输入张量,将其分成相等的部分,并将每个切片提供给不同的设备,这意味着,如果你有 16 个和 4 个 GPU 的批量大小,每个 GPU 将结束本地批处理 4 个数据点,在此之后,聚合结果和其他与我的问题无关的东西有一些魔力。
正如您已经注意到的,对于分配策略而言,将张量作为输入或至少某种具有外部批处理维度的输入至关重要,而我拥有的是 Python 字典,输入的批处理维度在内部字典张量值。这是一个大问题,我目前的实现与分布式训练不兼容。
我一直在寻找解决方法,但我无法很好地解决这个问题,也许只是将输入设为 shape=[batch_size, 2, 512, 512, 3] 的巨大张量并将其切片?不知道这只是我现在想到的,哈哈。无论如何,我认为这非常模棱两可,我无法区分这两个输入,至少不是字典键的清晰度。编辑:这个解决方案的问题是使我的数据集转换非常昂贵,因此使数据集吞吐量变慢,考虑到这是一个图像加载管道,这是一个重点。
也许我对分布式策略如何工作的解释不是最严格的,如果我没有看到任何内容,请随时纠正我。
PD:这不是bug问题或代码错误,主要是“系统设计查询”,希望这里不违法
【问题讨论】:
标签: python tensorflow tensorflow2.0 tensorflow-datasets