【发布时间】:2020-05-19 00:54:25
【问题描述】:
我找到了两种在 Tensorflow 中保存模型的方法:tf.train.Saver() 和 SavedModelBuilder。但是,在以第二种方式加载模型后,我找不到有关使用模型的文档。
注意:我想使用SavedModelBuilder 方式,因为我在 Python 中训练模型,并将在服务时使用另一种语言 (Go),而在这种情况下,SavedModelBuilder 似乎是唯一的方式。
这对tf.train.Saver() 非常有效(第一种方式):
model = tf.add(W * x, b, name="finalnode")
# save
saver = tf.train.Saver()
saver.save(sess, "/tmp/model")
# load
saver.restore(sess, "/tmp/model")
# IMPORTANT PART: REALLY USING THE MODEL AFTER LOADING IT
# I CAN'T FIND AN EQUIVALENT OF THIS PART IN THE OTHER WAY.
model = graph.get_tensor_by_name("finalnode:0")
sess.run(model, {x: [5, 6, 7]})
tf.saved_model.builder.SavedModelBuilder() 在Readme 中定义,但在使用tf.saved_model.loader.load(sess, [], export_dir) 加载模型后,我找不到返回节点的文档(请参阅上面代码中的"finalnode")
【问题讨论】:
-
注意:这个函数只能通过 v1 兼容库作为
tf.compat.v1.saved_model.builder.SavedModelBuilder或tf.compat.v1.saved_model.Builder使用。 Tensorflow 2.0 将引入一种基于对象的新方法来创建 SavedModel。
标签: tensorflow