【问题标题】:tensorflow dataset tf.estimator.inputs.numpy_input_fn张量流数据集 tf.estimator.inputs.numpy_input_fn
【发布时间】:2018-08-04 00:21:08
【问题描述】:

我正在编写一个代码,用于在 tensorflow 中从磁盘读取图像和标签,然后尝试调用 tf.estimator.inputs.numpy_input_fn。如何传递整个数据集而不是单个图像。我的代码如下所示:

filenames = tf.constant(filenames)
labels = tf.constant(labels)

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)
dataset_batched = dataset.batch(10)
iterator = dataset_batched.make_one_shot_iterator()
features, labels = iterator.get_next()

with tf.Session() as sess:

  print(dataset_batched)
  print(np.shape(sess.run(features)))
  print(np.shape(sess.run(labels)))

  mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_mk, model_dir=dir)
  train_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": np.array(sess.run(features))},
                                                  y=np.array(sess.run(labels)),
                                                  batch_size=1,
                                                  num_epochs=None,
                                                  shuffle=False)
  mnist_classifier.train(input_fn=train_input_fn, steps=1)

我的问题是如何在这里传递数据集x={"x": np.array(sess.run(features))}

【问题讨论】:

    标签: python tensorflow dataset


    【解决方案1】:

    这里不需要/使用numpy_input_fn。您应该将顶部的代码包装到一个返回iterator.get_next() 的函数(例如my_input_fn)中,然后将input_fn=my_input_fn 传递给train 调用。这会将完整的数据集以 10 个批次传递给训练代码。

    numpy_input_fn 适用于当您已经在数组中拥有完整的数据集并且想要快速进行批处理/改组/重复等操作时。

    【讨论】:

    • 我试图从我的input_fn 返回iterator.get_next(),但出现错误提示:ValueError: input_fn must return a tf.data.Dataset or a callable。你能帮忙吗?
    • 请发布一个单独的问题,并附上一个可重复的最小示例。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-12-27
    • 2021-01-30
    • 2020-08-16
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多