【发布时间】:2018-11-27 18:09:58
【问题描述】:
我正在使用 TensorFlow 和 tf.data.Dataset API 来执行一些文本预处理。在我的dataset.map 调用中不使用num_parallel_calls,预处理10K 记录需要0.03 秒。
当我使用num_parallel_trials=8(我机器上的核心数)时,预处理10K记录也需要0.03s。
我搜索了一下,发现了这个:Parallelism isn't reducing the time in dataset map
它们表明您需要使用 TensorFlow 操作来查看加速。事情是这样的:我 am 只使用 TensorFlow 操作。具体来说,我正在映射这个函数:
def preprocess(self, x, data_table):
x['reviews'] = tf.string_split(x['reviews'], delimiter=' ')
x['reviews'] = tf.sparse_tensor_to_dense(x['reviews'], default_value=' ')
x['reviews'] = tf.cast(data_table.lookup(x['reviews']), tf.int32)
nbatch = tf.cast(tf.shape(x['reviews'])[0], tf.int32)
nseq = tf.cast(tf.shape(x['reviews'])[1], tf.int32)
padding = tf.cond(tf.less(nseq, 100),
lambda: 0 * tf.ones([nbatch, 100 - nseq], tf.int32),
lambda: 0 * tf.ones([nbatch, 0], tf.int32))
x['reviews'] = tf.concat((x['reviews'], padding), axis=1)[:, :100]
x['reviews'].set_shape([None, 100])
return x
知道为什么我没有看到任何加速吗?
谢谢!
【问题讨论】:
-
可能有很多原因,但我看到
sparse_tensor_to_dense和lambda这两个操作是这里的瓶颈。但要进一步调查,您应该提供更多详细信息,以及您想要实现的目标以及数据集和管道的外观 -
preprocess()是您传递给Dataset.map()的函数吗? (问是因为我不希望data_table成为 map 函数中的参数。)正如 mlRocks 所建议的,查看输入管道的更大上下文会很有帮助。例如,如果您的输入数据在一个缓慢的存储系统上,您可能会遇到 I/O 瓶颈,map()中的并行性将无法恢复。
标签: python tensorflow tensorflow-datasets