【问题标题】:TensorFlow: How to predict from a SavedModel?TensorFlow:如何从 SavedModel 进行预测?
【发布时间】:2018-02-04 15:22:18
【问题描述】:

我已经导出了一个SavedModel,现在我可以重新加载它并进行预测。它使用以下特征和标签进行了训练:

F1 : FLOAT32
F2 : FLOAT32
F3 : FLOAT32
L1 : FLOAT32

假设我想输入值 20.9, 1.8, 0.9 得到一个 FLOAT32 预测。我该如何做到这一点?我已成功加载模型,但我不确定如何访问它以进行预测调用。

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        "/job/export/Servo/1503723455"
    )

    # How can I predict from here?
    # I want to do something like prediction = model.predict([20.9, 1.8, 0.9])

此问题与here 发布的问题不重复。这个问题的重点是对任何模型类的SavedModel 执行推理的最小示例(不仅限于tf.estimator)以及指定输入和输出节点名称的语法。

【问题讨论】:

标签: python machine-learning tensorflow tensorflow-serving


【解决方案1】:

加载图表后,它在当前上下文中可用,您可以通过它提供输入数据以获得预测。每个用例都有很大的不同,但添加到您的代码中的内容如下所示:

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        "/job/export/Servo/1503723455"
    )

    prediction = sess.run(
        'prefix/predictions/Identity:0',
        feed_dict={
            'Placeholder:0': [20.9],
            'Placeholder_1:0': [1.8],
            'Placeholder_2:0': [0.9]
        }
    )

    print(prediction)

在这里,您需要知道预测输入的名称。如果您没有在您的serving_fn 中给他们一个中殿,那么他们默认为Placeholder_n,其中n 是第n 个功能。

sess.run 的第一个字符串参数是预测目标的名称。这将根据您的用例而有所不同。

【讨论】:

  • 为什么我们不能将多个值传递给每个占位符,比如批处理或输入? prediction = sess.run( 'prefix/predictions/Identity:0', feed_dict={ 'Placeholder:0': [20.9, 11.3], 'Placeholder_1:0': [1.8, 2.6], 'Placeholder_2:0': [0.9, 0.76] } )
【解决方案2】:

假设您想要在 Python 中进行预测,SavedModelPredictor 可能是加载 SavedModel 并获取预测的最简单方法。假设您像这样保存模型:

# Build the graph
f1 = tf.placeholder(shape=[], dtype=tf.float32)
f2 = tf.placeholder(shape=[], dtype=tf.float32)
f3 = tf.placeholder(shape=[], dtype=tf.float32)
l1 = tf.placeholder(shape=[], dtype=tf.float32)
output = build_graph(f1, f2, f3, l1)

# Save the model
inputs = {'F1': f1, 'F2': f2, 'F3': f3, 'L1': l1}
outputs = {'output': output_tensor}
tf.contrib.simple_save(sess, export_dir, inputs, outputs)

(输入可以是任何形状,甚至不必是图中的占位符或根节点)。

然后,在将使用SavedModel 的 Python 程序中,我们可以得到如下预测:

from tensorflow.contrib import predictor

predict_fn = predictor.from_saved_model(export_dir)
predictions = predict_fn(
    {"F1": 1.0, "F2": 2.0, "F3": 3.0, "L1": 4.0})
print(predictions)

This answer 展示了如何在 Java、C++ 和 Python 中获得预测(尽管 问题 关注于 Estimators,但答案实际上与 SavedModel 的创建方式无关)。

【讨论】:

  • 当使用 tf.data.Dataset 及其迭代器从输入文件中读取模型的输入时,显然 simple_save 与图形构建代码不兼容,因为 simple_save 需要张量而不是 numpy 数组。此外,命名空间更改为 tf.saved_model 而不是 tf.contrib。我的代码可能是问题所在。使用使用 Dataset 训练并使用 simple_save 保存的模型的已知工作代码示例将非常适合 @mrry。
  • 显然 simple_save 的另一个不兼容之处在于输入稀疏 numpy 数组的图形制作代码,而图形制作代码的第一行是 tf.stack,因为它是一个稀疏矩阵。那么,在图形构建代码之外,您将 tf.stack 放在哪里?使用模型调用 tf.stack 并使用 simple_save 保存的已知工作代码示例将非常适合 @mrry
  • @GeoffreyAnderson 听起来值得自己提出问题;请务必发布您正在使用的代码的 sn-ps。
  • 请注意,这在 TensorFlow 2.0 中可能不再适用。至少,导入可能已经改变。
【解决方案3】:

对于任何需要保存经过训练的罐装模型并在没有 tensorflow 服务的情况下为其提供工作示例的人,我已在此处记录 https://github.com/tettusud/tensorflow-examples/tree/master/estimators

  1. 您可以从tf.tensorflow.contrib.predictor.from_saved_model( exported_model_path) 创建预测器
  2. 准备输入

    tf.train.Example( 
        features= tf.train.Features(
            feature={
                'x': tf.train.Feature(
                     float_list=tf.train.FloatList(value=[6.4, 3.2, 4.5, 1.5])
                )     
            }
        )    
    )
    

这里x 是导出时在 input_receiver_function 中给出的输入名称。 例如:

feature_spec = {'x': tf.FixedLenFeature([4],tf.float32)}

def serving_input_receiver_fn():
    serialized_tf_example = tf.placeholder(dtype=tf.string,
                                           shape=[None],
                                           name='input_tensors')
    receiver_tensors = {'inputs': serialized_tf_example}
    features = tf.parse_example(serialized_tf_example, feature_spec)
    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

【讨论】:

    【解决方案4】:

    tf.estimator.DNNClassifier 的构造函数有一个名为 warm_start_from 的参数。你可以给它SavedModel 文件夹名称,它会恢复你的会话。

    【讨论】:

    • warm_start_from 文件夹包含检查点或SavedModel?它们不一样。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2020-01-13
    • 2020-08-18
    • 1970-01-01
    • 1970-01-01
    • 2021-01-07
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多