【问题标题】:How to get class predictions from a trained Tensorflow classifier?如何从训练有素的 Tensorflow 分类器中获得类预测?
【发布时间】:2018-03-15 15:21:06
【问题描述】:

我已经训练了二元分类器模型。模型类包含self.costself.initial_stateself.final_stateself.logits 参数。用tf.train.Saver简单保存:

saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
saver.save(session, 'model.ckpt')

模型训练完成后,我将其加载为:

with tf.variable_scope("Model", reuse=False):
    model = MODEL(config, is_training=False)

with tf.Session() as session:
    saver = tf.train.Saver(tf.global_variables())
    saver.restore(session, 'model.ckpt')

但是,我的 model.run 函数返回交叉熵损失,这是图中的最后一个操作。我不需要损失,我需要每个批次元素的模型预测

logits = tf.sigmoid(tf.nn.xw_plus_b(last_layer, self.output_w, self.output_b))

其中last_layer 是一个800x1 矩阵,然后我将其重塑为32x25x1 (batch_size, sequence_length, 1) 矩阵。正是这个矩阵包含了 [0-1] 范围内的模型预测值。

那么,我如何使用这个模型来预测单元素矩阵1x1x1

【问题讨论】:

    标签: python tensorflow machine-learning


    【解决方案1】:

    添加计算准确性所需的 OP,类似于我在下面复制的内容(只是从我手头最接近的模型中复制出来的)。

      self.logits_flat = tf.argmax(logits, axis=1, output_type=tf.int32)
      labels_flat = tf.argmax(labels, axis=1, output_type=tf.int32)
      accuracy = tf.cast(tf.equal(self.logits_flat, labels_flat), tf.float32, name='accuracy')
    

    现在,当您运行模型(在测试或训练期间)时,将准确度添加到 sess.run 调用中:

    sess.run([train_op, accuracy], feed_dict=...)
    

    sess.run([accuracy, logits], feed_dict=...)
    

    当您调用 sess.run 时,您所做的只是告诉 tensorflow 计算您所要求的任何值的值。您需要将其传递给执行这些计算所需的任何数据。 Tensorflow 是惰性的,它不会执行任何非明确必要的计算来产生您请求的结果。例如。如果您运行上面列出的 sess.run 的第二个版本,优化器将不会运行,因此您的权重将不会更新。

    请注意,您可以在网络训练后添加 OP,因为它们实际上都没有添加任何变量,因此它们不会影响保存/恢复过程。

    【讨论】:

    • 我明白了。如果我不需要计算准确度而只需要每个输入标签的预测分数怎么办? session.run({"logits": self.logits}, feed_dict) 会起作用还是我必须通过 train_op 才能启用顺序依赖?
    • 完全正确。在这种情况下,张量流的魔力将找出计算logits 所需的最小必要依赖项并返回它。事实上,你不能请求train_op,如果你这样做,它会更新你的权重,而不是你想要的。请注意如何使用相同的图表,甚至相同的会话,在训练完成后执行推理,只需更改您从 sess.run 请求的 OP。
    • 另外请注意,如果您只请求logits,则不需要传入labels 占位符,因为这不是仅计算logits 的一部分,而是计算的一部分accuracy 如果你要求的话。
    猜你喜欢
    • 2020-01-30
    • 2018-10-13
    • 1970-01-01
    • 2018-11-03
    • 2019-06-07
    • 2019-10-10
    • 2017-08-20
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多