【问题标题】:Tensorflow py_func TypeError with tf.data.Dataset output带有 tf.data.Dataset 输出的 Tensorflow py_func TypeError
【发布时间】:2018-09-11 22:23:28
【问题描述】:

我正在尝试使用 py_func 从函数中返回数据集,以便在 tensorflow 数据集管道/api 中使用。但是,py_func 会抛出错误:

TypeError: Expected DataType for argument 'Tout' not <class 'tensorflow.python.data.ops.dataset_ops.Dataset'>.

一个最小的例子如下:

import tensorflow as tf
import numpy as np


def fn(x, y):
    a = tf.data.Dataset_from_tensors((x, y))
    b = tf.data.Dataset_from_tensors((x, y))
    return a.concatenate(b)

if __name__ == "__main__":
    features = np.random.rand(5, 5, 5, 1)
    labels = np.random.rand(5, 5)

    dataset = tf.data.Dataset.from_tensors((features, labels))
    dataset = dataset.flat_map(
        lambda feature, label:  tuple(tf.py_func(
            fn, [feature, label], [tf.data.Dataset])))
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    sess = tf.Session()
    val = sess.run(next_element)

这是 tensorflow 的错误,还是我错误地使用了 api?谢谢!

【问题讨论】:

  • 我认为你用错了。 Tout: A list or tuple of tensorflow data types or a single tensorflow data type if there is only one, indicating what func returns.
  • 如果你指出它应该是`tf.data.Dataset`而不是[tf.data.Dataset],两者都报同样的错误,否则我想我误解了你在说什么。 tf.data.Dataset 不是 tensorflow data type,因为我的理解是 fn 应该返回的内容。

标签: python tensorflow tensorflow-datasets


【解决方案1】:

tf.py_func 使我们能够在等效的 Tensorflow API 不易获得时包装 python 代码。

我认为它不能返回 Tensorflow 类。

作为一个例子,我复制了我在另一个线程中发布的这段代码。此代码使用 Python API 来格式化从文件中读取的日期。并且返回的数据类型是 Tensorflow 数据类型。

tf.py_func 方便无缝地包含在数据管道中,但受API doc. 中提到的限制

import tensorflow as tf
from datetime import datetime


sess = tf.Session()

#Could be refactored
def convert_to_date(text):
    date = datetime.strptime(text.decode('ascii'), '%b %d %Y %I:%M%p')
    return date.strftime('%b %d %Y %I:%M%p')

filenames = ["C:/Machine Learning/text.txt"]

dataset = tf.data.Dataset.from_tensor_slices(filenames)

tf.data.TextLineDataset
dataset = dataset.flat_map(
                   lambda filename :
                   tf.data.TextLineDataset( filename ) ).map( lambda text :
                                                                        tf.py_func(convert_to_date,
                                                                                  [text],
                                                                                  [tf.string]))
iterator = dataset.make_one_shot_iterator()
date = iterator.get_next()


print(sess.run([date]))

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-02-18
    • 2016-06-03
    • 2018-11-09
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多