【问题标题】:how to load and use a saved model on tensorflow?如何在 tensorflow 上加载和使用保存的模型?
【发布时间】:2020-05-19 00:54:25
【问题描述】:

我找到了两种在 Tensorflow 中保存模型的方法:tf.train.Saver()SavedModelBuilder。但是,在以第二种方式加载模型后,我找不到有关使用模型的文档

注意:我想使用SavedModelBuilder 方式,因为我在 Python 中训练模型,并将在服务时使用另一种语言 (Go),而在这种情况下,SavedModelBuilder 似乎是唯一的方式。

这对tf.train.Saver() 非常有效(第一种方式):

model = tf.add(W * x, b, name="finalnode")

# save
saver = tf.train.Saver()
saver.save(sess, "/tmp/model")

# load
saver.restore(sess, "/tmp/model")

# IMPORTANT PART: REALLY USING THE MODEL AFTER LOADING IT
# I CAN'T FIND AN EQUIVALENT OF THIS PART IN THE OTHER WAY.

model = graph.get_tensor_by_name("finalnode:0")
sess.run(model, {x: [5, 6, 7]})

tf.saved_model.builder.SavedModelBuilder()Readme 中定义,但在使用tf.saved_model.loader.load(sess, [], export_dir) 加载模型后,我找不到返回节点的文档(请参阅上面代码中的"finalnode"

【问题讨论】:

  • 注意:这个函数只能通过 v1 兼容库作为tf.compat.v1.saved_model.builder.SavedModelBuildertf.compat.v1.saved_model.Builder 使用。 Tensorflow 2.0 将引入一种基于对象的新方法来创建 SavedModel。

标签: tensorflow


【解决方案1】:

缺少的是signature

# Saving
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
builder.add_meta_graph_and_variables(sess, ["tag"], signature_def_map= {
        "model": tf.saved_model.signature_def_utils.predict_signature_def(
            inputs= {"x": x},
            outputs= {"finalnode": model})
        })
builder.save()

# loading
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, ["tag"], export_dir)
    graph = tf.get_default_graph()
    x = graph.get_tensor_by_name("x:0")
    model = graph.get_tensor_by_name("finalnode:0")
    print(sess.run(model, {x: [5, 6, 7, 8]}))

【讨论】:

  • 这在 2.0 中如何工作? tf.saved_model.builder 不包含在新版本中
  • @jregalad 你想阅读保存的模型吗?如果是,请分享您如何在 2.0+ 版本中保存保存的模型。
【解决方案2】:

这是使用 simple_save 加载和恢复/预测模型的代码 sn-p

#Save the model:
tf.saved_model.simple_save(sess, export_dir=saveModelPath,
                                   inputs={"inputImageBatch": X_train, "inputClassBatch": Y_train,
                                           "isTrainingBool": isTraining},
                                   outputs={"predictedClassBatch": predClass})

请注意,使用 simple_save 会设置某些默认值(可以在以下位置查看:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/simple_save.py

现在,恢复和使用输入/输出字典:

from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import signature_constants

with tf.Session() as sess:
  model = tf.saved_model.loader.load(export_dir=saveModelPath, sess=sess, tags=[tag_constants.SERVING]) #Note the SERVINGS tag is put as default.

  inputImage_name = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].inputs['inputImageBatch'].name
  inputImage = tf.get_default_graph().get_tensor_by_name(inputImage_name)

  inputLabel_name = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].inputs['inputClassBatch'].name
  inputLabel = tf.get_default_graph().get_tensor_by_name(inputLabel_name)

  isTraining_name = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].inputs['isTrainingBool'].name
  isTraining = tf.get_default_graph().get_tensor_by_name(isTraining_name)

  outputPrediction_name = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs['predictedClassBatch'].name
  outputPrediction = tf.get_default_graph().get_tensor_by_name(outputPrediction_name)

  outPred = sess.run(outputPrediction, feed_dict={inputImage:sampleImages, isTraining:False})

  print("predicted classes:", outPred)

注意:需要默认的 signature_def 才能使用输入和输出字典中指定的张量名称。

【讨论】:

    【解决方案3】:

    Tensorflow 构建和使用不同语言模型的首选方式是tensorflow serving

    现在,在您的情况下,您正在使用 saver.save 来保存模型。这样它会保存一个meta 文件、ckpt 文件和一些其他文件来保存权重和网络信息、训练的步数等。这是在训练时保存的首选方式。

    如果您现在完成了训练,您应该使用SavedModelBuilder 从您保存的由saver.save 保存的文件中冻结图形。此冻结图包含一个pb 文件,并包含所有网络和权重。

    这个冻结的模型应该被tensorflow serving使用,然后其他语言可以使用gRPC协议的模型。

    整个过程在this优秀教程中有描述。

    【讨论】:

    • 感谢您的回答和链接,但这并不能回答我的问题...
    • 链接确实在“最后一步 — 保存模型”之后的某处有答案,但只有在您已经知道在哪里查找时才容易找到...它绝对可以更简洁,但也感谢您的链接和见解
    【解决方案4】:

    一个代码 sn-p 可用于加载 pb 文件并在单个图像上进行推理。

    代码遵循以下步骤:将 pb 文件加载到 GraphDef(图的序列化版本(用于读取 pb 文件)中,将 GraphDef 加载到 Graph 中,通过名称获取输入和输出张量,推断 a单张图片。

    import tensorflow as tf 
    import numpy as np
    import cv2
    
    INPUT_TENSOR_NAME = 'input_tensor_name:0'
    OUTPUT_TENSOR_NAME = 'output_tensor_name:0'
    
    # Read image, get shape
    # Add dimension to fit batch shape
    img = cv2.imread(IMAGE_PATH)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    image = img.astype(float)
    height, width, channels = image.shape
    image = np.expand_dims(image, 0)  # Add dimension (to fit batch shape)
    
    
    # Read pb file into the graph as GraphDef - Serialized version of a graph     (used to read pb files)
    with tf.gfile.FastGFile(PB_PATH, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    
    # Load GraphDef into Graph
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
    
    # Get tensors (input and output) by name
    input_tensor = graph.get_tensor_by_name(INPUT_TENSOR_NAME)
    output_tensor = graph.get_tensor_by_name(OUTPUT_TENSOR_NAME)
    
    # Inference on single image
    with tf.Session(graph=graph) as sess:
        output_vals = sess.run(output_tensor, feed_dict={input_tensor: image})  #
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-04-12
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多