【Question title】: Can't restore pre-trained network with Tensorflow
【Posted】: 2018-03-05 16:42:41
【Question】:

I've been struggling to restore a pre-trained network with TensorFlow...

import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'

sess=tf.Session()
saver = tf.train.import_meta_graph('./model/20170512-110547/model-20170512-110547.meta')
saver.restore(sess,'./model/20170512-110547/')

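I want to use a network pre-trained for face recognition and then add some layers on top of it for transfer learning. (I downloaded the model from here: https://github.com/davidsandberg/facenet)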

When I run the code above, I get the following error:

WARNING:tensorflow:The saved meta_graph is possibly from an older release:
'model_variables' collection should be of type 'byte_list', but instead is of type 'node_list'.
Traceback (most recent call last):
  File "/Users/user/Desktop/desktop/Python/HCR/Transfer_face/test.py", line 7, in <module>
    saver.restore(sess,'./model/20170512-110547/')
  File "/Users/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1560, in restore
    {self.saver_def.filename_tensor_name: save_path})
  File "/Users/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 895, in run
    run_metadata_ptr)
  File "/Users/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1124, in _run
    feed_dict_tensor, options, run_metadata)
  File "/Users/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1321, in _do_run
    options, run_metadata)
  File "/Users/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1340, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: Unsuccessful TensorSliceReader constructor: Failed to find any matching files for ./model/20170512-110547/
     [[Node: save/RestoreV2_491 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_491/tensor_names, save/RestoreV2_491/shape_and_slices)]]

Caused by op u'save/RestoreV2_491', defined at:
  File "/Users/user/Desktop/desktop/Python/HCR/Transfer_face/test.py", line 6, in <module>
    saver = tf.train.import_meta_graph('./model/20170512-110547/model-20170512-110547.meta')
  File "/Users/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1698, in import_meta_graph
    **kwargs)
  File "/Users/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/meta_graph.py", line 656, in import_scoped_meta_graph
    producer_op_list=producer_op_list)
  File "/Users/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/importer.py", line 313, in import_graph_def
    op_def=op_def)
  File "/Users/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2630, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/Users/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1204, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

NotFoundError (see above for traceback): Unsuccessful TensorSliceReader constructor: Failed to find any matching files for ./model/20170512-110547/
     [[Node: save/RestoreV2_491 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_491/tensor_names, save/RestoreV2_491/shape_and_slices)]]

I don't understand why it can't find the pre-trained data... The directory structure is as follows:

USER-no-MacBook-Pro:Transfer_face user$ ls -R
model    test.py

./model:
20170512-110547

./model/20170512-110547:
20170512-110547.pb
model-20170512-110547.ckpt-250000.index
model-20170512-110547.ckpt-250000.data-00000-of-00001
model-20170512-110547.meta

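【Comments on the question】:

  • Try an older version of TensorFlow; the warning says "The saved meta_graph is possibly from an older release". The model was built with r0.12.
  • Thanks. I tried versions 0.12 and 1.2.0 (the version listed in the requirements), but it still shows the same error...
  • Try passing the full absolute path to the model directory (instead of the relative path './model/20170512-110547/') when calling saver.restore(). Older versions of TensorFlow (including 0.12, I think) had a bug where some APIs did not accept relative paths, but this should be fixed in recent versions.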
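A minimal sketch of the absolute-path suggestion, assuming `model/` sits next to `test.py` as in the directory listing. A relative path such as `./model/...` is resolved against the process's current working directory, not the script's location, so anchoring it to `__file__` makes it independent of where the script is launched from. `model_path` is a hypothetical helper name, not part of TensorFlow.

```python
import os


def model_path(base_dir, *parts):
    """Join path components onto base_dir and return an absolute, normalized path.

    Hypothetical helper: pass the script's own directory as base_dir so the
    result does not depend on the current working directory.
    """
    return os.path.abspath(os.path.join(base_dir, *parts))


# Hypothetical usage inside test.py (TensorFlow 1.x):
# script_dir = os.path.dirname(os.path.abspath(__file__))
# ckpt = model_path(script_dir, "model", "20170512-110547",
#                   "model-20170512-110547.ckpt-250000")
# saver.restore(sess, ckpt)
```

Note that the usage points at the checkpoint prefix rather than the folder: an absolute path alone does not help if it still names a directory, because `saver.restore()` does not search inside it.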

Tags: python tensorflow deep-learning pre-trained-model


【Answer 1】:

Import the .pb file instead.

import tensorflow as tf
from tensorflow.python.framework import tensor_util

# Path assumes the directory layout shown in the question
with tf.gfile.GFile('./model/20170512-110547/20170512-110547.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Import into the default graph
tf.import_graph_def(graph_def)

# Print some data: the weights are stored as Const nodes in the .pb graph
wts = [n for n in graph_def.node if n.op == 'Const']

for n in wts:
    print(tensor_util.MakeNdarray(n.attr['value'].tensor))

Related questions:

Import a simple Tensorflow frozen_model.pb file and make prediction in C++

get the value weights from .pb file by Tensorflow

Related documentation: GraphDef

【Discussion】:

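【Answer 2】:

You need to pass the checkpoint path './model/20170512-110547/model-20170512-110547.ckpt-250000' instead of the folder path. saver.restore() expects the checkpoint prefix, i.e. the file name without the .index / .data-00000-of-00001 suffix.

【Discussion】: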
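A minimal sketch of deriving that prefix automatically, assuming the checkpoint naming shown in the question (`<name>.ckpt-<step>.index` plus `.data-*` shards). The function names are hypothetical; the logic lists the folder, keeps the `.index` files, strips the suffix, and prefers the highest step if several checkpoints are present.

```python
import os
import re

# Matches index files such as 'model-20170512-110547.ckpt-250000.index'
# and captures the checkpoint prefix and its global step.
_INDEX_RE = re.compile(r"^(?P<prefix>.+\.ckpt-(?P<step>\d+))\.index$")


def checkpoint_prefix_from_names(model_dir, filenames):
    """Pick the checkpoint prefix with the highest step from a list of file names."""
    candidates = []
    for name in filenames:
        m = _INDEX_RE.match(name)
        if m:
            candidates.append((int(m.group("step")), m.group("prefix")))
    if not candidates:
        raise IOError("No '*.ckpt-<step>.index' file found in %r" % model_dir)
    _, prefix = max(candidates)
    return os.path.join(model_dir, prefix)


def find_checkpoint_prefix(model_dir):
    """Return e.g. '<model_dir>/model-20170512-110547.ckpt-250000'."""
    return checkpoint_prefix_from_names(model_dir, os.listdir(model_dir))
```

With TensorFlow 1.x this would be used as `saver.restore(sess, find_checkpoint_prefix('./model/20170512-110547'))`. `tf.train.latest_checkpoint()` does a similar job, but it relies on a `checkpoint` state file, which does not appear in the question's directory listing, so it would most likely return `None` there.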
