【发布时间】:2018-09-11 22:23:28
【问题描述】:
我正在尝试使用 py_func 从函数中返回数据集,以便在 tensorflow 数据集管道/api 中使用。但是,py_func 会抛出错误:
TypeError: Expected DataType for argument 'Tout' not <class 'tensorflow.python.data.ops.dataset_ops.Dataset'>.
一个最小的例子如下:
import tensorflow as tf
import numpy as np
def fn(x, y):
a = tf.data.Dataset_from_tensors((x, y))
b = tf.data.Dataset_from_tensors((x, y))
return a.concatenate(b)
if __name__ == "__main__":
features = np.random.rand(5, 5, 5, 1)
labels = np.random.rand(5, 5)
dataset = tf.data.Dataset.from_tensors((features, labels))
dataset = dataset.flat_map(
lambda feature, label: tuple(tf.py_func(
fn, [feature, label], [tf.data.Dataset])))
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
val = sess.run(next_element)
这是 tensorflow 的错误,还是我错误地使用了 api?谢谢!
【问题讨论】:
-
我认为你用错了。
Tout: A list or tuple of tensorflow data types or a single tensorflow data type if there is only one, indicating what func returns. -
如果你指出它应该是`tf.data.Dataset`而不是
[tf.data.Dataset],两者都报同样的错误,否则我想我误解了你在说什么。 tf.data.Dataset 不是tensorflow data type,因为我的理解是 fn 应该返回的内容。
标签: python tensorflow tensorflow-datasets