【发布时间】:2018-08-04 00:21:08
【问题描述】:
我正在编写一个代码,用于在 tensorflow 中从磁盘读取图像和标签,然后尝试调用 tf.estimator.inputs.numpy_input_fn。如何传递整个数据集而不是单个图像。我的代码如下所示:
filenames = tf.constant(filenames)
labels = tf.constant(labels)
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)
dataset_batched = dataset.batch(10)
iterator = dataset_batched.make_one_shot_iterator()
features, labels = iterator.get_next()
with tf.Session() as sess:
print(dataset_batched)
print(np.shape(sess.run(features)))
print(np.shape(sess.run(labels)))
mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_mk, model_dir=dir)
train_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": np.array(sess.run(features))},
y=np.array(sess.run(labels)),
batch_size=1,
num_epochs=None,
shuffle=False)
mnist_classifier.train(input_fn=train_input_fn, steps=1)
我的问题是如何在这里传递数据集x={"x": np.array(sess.run(features))}
【问题讨论】:
标签: python tensorflow dataset