【发布时间】:2017-03-20 00:29:36
【问题描述】:
我有一个 ckpt 文件。我只想得到cnn的权重 我是从 ckpt 检查点文件中训练出来的。? inception_resnet_v2_2016_08_30
import tensorflow as tf
saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, "inception_resnet_v2_2016_08_30.ckpt")
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
from tensorflow.python.training import saver as saver_lib
with session.Session() as sess:
var_list = {}
reader =pywrap_tensorflow.NewCheckpointReader("./inception_resnet_v2_2016_08_30.ckpt")
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
try:
tensor = sess.graph.get_tensor_by_name(key + ":0")
except KeyError:
continue
var_list[key] = tensor
saver = saver_lib.Saver(var_list=var_list)
saver.restore(sess, input_checkpoint)
if initializer_nodes:
sess.run(initializer_nodes)
【问题讨论】:
-
您问题中的代码无法正常工作,原因与
saver.restore()不起作用的原因相同:没有可将张量恢复到其中的图形。加载检查点值后,您想对其执行什么操作?您可以调用reader.get_tensor(key)以将检查点值作为 NumPy 数组获取。您可能应该更改for循环的主体来执行此操作。
标签: python-3.x tensorflow deep-learning