【问题标题】:Looping over tf.data.Dataset very slow循环 tf.data.Dataset 非常慢
【发布时间】:2020-06-18 15:39:49
【问题描述】:

我想知道为什么对 tf.data.Dataset 样本的 for 循环比在相应的 numpy 数组上循环慢得多。

import numpy as np
import tensorflow as tf
import time

a = np.ones(100000, dtype=np.float32)

start_time = time.time()
for x in a:
    pass
print(time.time() - start_time)

start_time = time.time()
for x in tf.data.Dataset.from_tensor_slices(a):
    pass
print(time.time() - start_time)

0.05548405647277832
5.67711615562439

我的 TensorFlow 版本是 2.0.0。

【问题讨论】:

    标签: python tensorflow tensorflow2.0 tensorflow-datasets


    【解决方案1】:

    是的,即使我也观察到相同的行为。要提高速度/性能,请尝试将 tf.data.dataset 包装在 @tf.function 中,这将花费几乎相同的时间。

    AutoGraphtf.function 中是默认设置,并将您的 Python Eager 代码转换为与图形兼容的 TensorFlow 操作。这包括控制流,如ifforwhile

    tf.function 最适用于 TensorFlow ops,NumPy 和 Python 调用被转换为常量。

    请参考下面显示的代码以包裹在@tf.function

    @tf.function
    
    def oper(a):
        start_time = time.time()
        for x in tf.data.Dataset.from_tensor_slices(a):
            pass
        print(time.time() - start_time)
    

    下面显示numpytf.data.dataset 性能之间的完整工作代码

    import numpy as np
    import tensorflow as tf
    import time
    
    a = np.ones(100000, dtype=np.float32)
    
    start_time = time.time()
    for x in a:
        pass
    print(time.time() - start_time)
    
    
    @tf.function
    
    def oper(a):
        start_time = time.time()
    
        for x in tf.data.Dataset.from_tensor_slices(a):
            pass
        print(time.time() - start_time)
    
    oper(a) 
    

    输出:

    0.012496232986450195
    0.017792224884033203
    

    要了解更多关于tf.function的信息,请参考this

    【讨论】:

    • 函数被跟踪的时候不是刚刚执行打印吗?
    猜你喜欢
    • 2022-01-22
    • 2012-09-20
    • 2011-08-09
    • 2018-11-18
    • 1970-01-01
    • 2012-10-28
    • 1970-01-01
    • 2016-09-07
    • 1970-01-01
    相关资源
    最近更新 更多