【发布时间】:2019-03-04 22:25:54
【问题描述】:
我有一个 TensorFlow 模型,它使用 tf.data.Dataset feedable 迭代器在训练和验证之间切换。两个数据集共享相同的结构,即它们具有特征矩阵和相应的标签向量。为了使用相同的模型和迭代器进行推理(没有标签向量只有特征矩阵),我需要理想地提供一个零标签向量。是否有更高效、更优雅的方式将数据集 API 用于训练(验证)和推理?
在代码中:
training_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
validation_dataset = tf.data.Dataset.from_tensor_slices((X_validation, y_validation))
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, training_dataset.output_types, training_dataset.output_shapes)
features, labels = iterator.get_next()
特征和标签在模型内部用作输入占位符。 为了在数据集之间切换,我需要为每个数据集创建一个迭代器:
training_iterator = training_dataset.make_initializable_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()
然后创建句柄
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())
并使用handle 选择要使用的数据集,例如:
sess.run(next_element, feed_dict={handle: training_handle})
现在,如果我有没有标签的推理数据会怎样?
inference_dataset = tf.data.Dataset.from_tensor_slices(X_inference) # NO y values
inferece_iterator = inference_dataset.make_initializable_iterator()
如果我添加这个迭代器,它会抛出异常,因为“组件数不匹配:预期 2 种类型但得到 1 种。” 有什么建议吗?
这个帖子How to use tf.Dataset design in both training and inferring?和这个问题有关,但是tf.data.Dataset没有解压方法。
解决此问题的最佳做法是什么?
【问题讨论】:
标签: python tensorflow tensorflow-datasets