【发布时间】:2018-10-01 00:18:21
【问题描述】:
我正在尝试对我拥有的数据集 (X_train) 进行推理,并在没有将 softmax 应用于输出的情况下获取 logits 层的值。我从检查点文件 (model_X.ckpt) 加载的模型具有名为“logits”的 logits 层。所以基本上,我想跑:
sess.run("model_X/logits:0", feed_dict: {"Placeholder:0": X_train, keep_prob:1.0})
但是模型将输入数据集的大小限制为 32,这不允许我一次传递 10,000 个输入。这就是为什么我也使用批量创建的原因:
features_placeholder = tf.placeholder(X_train.dtype, X_train.shape)
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder))
dataset = dataset.batch(32)
我创建了一个迭代器(完全基于我的推理,使用 one-shot-iterator 会扩大我的图形大小):
def initialize_iterator(sess, iterator, features):
sess.run(iterator.initializer, feed_dict={features_placeholder: features})
iterator = dataset.make_initializable_iterator()
initialize_iterator(sess, iterator, X_train)
next_x = iterator.get_next()
# Assign the first batch:
val = sess.run(next_x)
layer = "model_X/logits:0"
units = sess.run(layer,feed_dict={"Placeholder:0": val, keep_prob:1.0})
如何遍历所有批次以推断所有输入?
【问题讨论】:
标签: python tensorflow inference