【Question title】: Error on restoring TensorFlow checkpoint file
【Posted】: 2015-12-16 16:39:47
【Question】:

I get the following error when calling saver.restore() in TensorFlow. Any idea why this happens?

I save the model like this: saver.save(sess, checkpoint_path, global_step=step)

The error is:

tensorflow.python.framework.errors.InvalidArgumentError: Node 'Variable_1/Assign': Unknown input node Variable_1
     [[Node: Variable_1/initial_value = Const[dtype=DT_FLOAT, value=Tensor<type: float shape: [] values: 0.9>]()]]

Full trace:

can't determine number of CPU cores: assuming 4
I tensorflow/core/common_runtime/local_device.cc:25] Local device intra op parallelism threads: 4
can't determine number of CPU cores: assuming 4
I tensorflow/core/common_runtime/local_session.cc:45] Local session inter op parallelism threads: 4
('1.1- label batch shape is ', TensorShape([Dimension(128)]))
Inferencing
('in inferemcee ', TensorShape([Dimension(128), Dimension(3072)]), <class 'tensorflow.python.framework.ops.Tensor'>)
Evaluation..
tmp/ckpt/model.ckpt-9100
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc789748be0 Compute status: Cancelled: Enqueue operation was cancelled
     [[Node: input/string_input_producer/string_input_producer_EnqueueMany = QueueEnqueueMany[Tcomponents=[DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/string_input_producer, input/string_input_producer/limit_epochs)]]
I tensorflow/core/kernels/fifo_queue.cc:154] Skipping cancelled enqueue attempt
Traceback (most recent call last):
  File "/ProjectS/Cifar-Eval/my_eval.py", line 112, in <module>
    tf.app.run()
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc78b939670 Compute status: Cancelled: Enqueue operation was cancelled
     [[Node: input/batching_shuffling/random_shuffle_queue_enqueue = QueueEnqueue[Tcomponents=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/batching_shuffling/random_shuffle_queue, input/sub, input/Cast_2)]]
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/platform/default/_app.py", line 11, in run
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc78954f080 Compute status: Cancelled: Enqueue operation was cancelled
     [[Node: input/batching_shuffling/random_shuffle_queue_enqueue = QueueEnqueue[Tcomponents=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/batching_shuffling/random_shuffle_queue, input/sub, input/Cast_2)]]
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc78954e5d0 Compute status: Cancelled: Enqueue operation was cancelled
     [[Node: input/batching_shuffling/random_shuffle_queue_enqueue = QueueEnqueue[Tcomponents=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/batching_shuffling/random_shuffle_queue, input/sub, input/Cast_2)]]
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc789550370 Compute status: Cancelled: Enqueue operation was cancelled
     [[Node: input/batching_shuffling/random_shuffle_queue_enqueue = QueueEnqueue[Tcomponents=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/batching_shuffling/random_shuffle_queue, input/sub, input/Cast_2)]]
W tensorflow/core/common_runtime/executor.cc:1027] 0x7fc78ba28cb0 Compute status: Cancelled: Enqueue operation was cancelled
     [[Node: input/batching_shuffling/random_shuffle_queue_enqueue = QueueEnqueue[Tcomponents=[DT_FLOAT, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input/batching_shuffling/random_shuffle_queue, input/sub, input/Cast_2)]]
    sys.exit(main(sys.argv))
  File "/ProjectS/Cifar-Eval/my_eval.py", line 108, in main
    my_eval()
  File "/ProjectS/Cifar-Eval/my_eval.py", line 85, in my_eval
    saver.restore(sess, ckpt.model_checkpoint_path)
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 864, in restore
    sess.run([self._restore_op_name], {self._filename_tensor_name: save_path})
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 345, in run
    results = self._do_run(target_list, unique_fetch_targets, feed_dict_string)
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 419, in _do_run
    e.code)
tensorflow.python.framework.errors.InvalidArgumentError: Node 'Variable_1/Assign': Unknown input node Variable_1
     [[Node: Reshape/shape = Const[dtype=DT_INT32, value=Tensor<type: int32 shape: [4] values: -1 32 32...>]()]]
Caused by op u'Reshape/shape', defined at:
  File "/ProjectS/Cifar-Eval/my_eval.py", line 112, in <module>
    tf.app.run()
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/platform/default/_app.py", line 11, in run
    sys.exit(main(sys.argv))
  File "/ProjectS/Cifar-Eval/my_eval.py", line 108, in main
    my_eval()
  File "/ProjectS/Cifar-Eval/my_eval.py", line 78, in my_eval
    logits = my_cifar.inference(images_placeholder)
  File "/ProjectS/Cifar-Eval/my_cifar.py", line 68, in inference
    images = tf.reshape(images, shape=[-1, 32, 32, 3])
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 554, in reshape
    name=name)
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/op_def_library.py", line 397, in apply_op
    values, name=input_arg.name, dtype=dtype)
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 468, in convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name)
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/constant_op.py", line 147, in constant
    attrs={"value": tensor_value, "dtype": dtype_value}, name=name).outputs[0]
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1710, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/Users/user/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 988, in __init__
    self._traceback = _extract_stack()

My checkpoint restore code:

import tensorflow as tf

import my_cifar
import my_input

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('eval_dir', 'tmp/log_eval',
                           """Directory where to write event logs.""")

tf.app.flags.DEFINE_string('checkpoint_dir', 'tmp/ckpt',
                           """Directory where to read model checkpoints.""")


IMAGE_PIXELS = 32 * 32 * 3


def placeholder_inputs(batch_size):
  """Generate placeholder variables to represent the input tensors.
  These placeholders are used as inputs by the rest of the model building
  code and will be fed from the downloaded ckpt in the .run() loop, below.
  Args:
    batch_size: The batch size will be baked into both placeholders.
  Returns:
    images_placeholder: Images placeholder.
    labels_placeholder: Labels placeholder.
  """
  # Note that the shapes of the placeholders match the shapes of the full
  # image and label tensors, except the first dimension is now batch_size
  # rather than the full size of the train or test ckpt sets.
  # batch_size = -1
  images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
                                                         IMAGE_PIXELS))
  # 32, 32, 3))
  labels_placeholder = tf.placeholder(tf.int32, shape=batch_size)

  return images_placeholder, labels_placeholder


def my_eval():
  with tf.Graph().as_default():

    v1 = tf.Variable(0)

    images_placeholder, labels_placeholder = placeholder_inputs(FLAGS.batch_size)

    # Get images and labels for CIFAR-10.
    val_images, val_labels = my_input.inputs(False)

    init_op = tf.initialize_all_variables()

    coord = tf.train.Coordinator()

    with tf.Session() as sess:

      sess.run(init_op)

      saver = tf.train.Saver()
      # Start the queue runners.

      threads = tf.train.start_queue_runners(sess=sess, coord=coord)

      summary_op = tf.merge_all_summaries()
      summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir,
                                              graph_def=sess.graph_def)


      # Build a Graph that computes the logits predictions from the
      # inference model.
      logits = my_cifar.inference(images_placeholder)

      acc = my_cifar.evaluation(logits, labels_placeholder)

      ckpt = tf.train.get_checkpoint_state(checkpoint_dir=FLAGS.checkpoint_dir)
      print ckpt.model_checkpoint_path
      if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('Restored!')

      images_val_r, labels_val_r = sess.run([val_images, val_labels])
      val_feed = {images_placeholder: images_val_r,
                  labels_placeholder: labels_val_r}

      tf.scalar_summary('Acc', acc)

      print('Calculating Acc  :')

      acc_r = sess.run(acc, feed_dict=val_feed)
      print(acc_r)

      # Write results to TensorBoard
      summary_str = sess.run(summary_op)
      summary_writer.add_summary(summary_str)


      coord.join(threads)


def main(argv=None):
  my_eval()


if __name__ == '__main__':
  tf.app.run()

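【Question discussion】:

  • @mrry please let me know if you see anything wrong here.
  • Try giving every variable an explicit name; the auto-generated variable names may differ when the checkpoint is loaded.

Tags: python tensorflow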


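【Solution 1】:

You are trying to restore a variable that does not exist in the original network, so it is not in the checkpoint either. I believe omitting

    v1 = tf.Variable(0)

will solve the problem.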
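
Why one extra variable matters can be sketched without TensorFlow. A default Saver matches variables to checkpoint entries by name, and unnamed variables get auto-generated names (Variable, Variable_1, ...), as the Variable_1 in the error message suggests. The ToyGraph class and the two-variable model below are hypothetical stand-ins that only imitate this numbering; they are not TensorFlow APIs.

```python
from collections import defaultdict


class ToyGraph(object):
    """Toy stand-in for a graph's name bookkeeping; not a TensorFlow API."""

    def __init__(self):
        self._counts = defaultdict(int)
        self.variables = []

    def add_variable(self, name="Variable"):
        # The first use of a name is kept as-is; later uses get _1, _2, ...
        n = self._counts[name]
        self._counts[name] += 1
        unique = name if n == 0 else "%s_%d" % (name, n)
        self.variables.append(unique)
        return unique


# Training graph: two unnamed model variables (hypothetical model).
train = ToyGraph()
train_weights = train.add_variable()
train_biases = train.add_variable()
checkpoint_keys = set(train.variables)

# Evaluation graph: one extra unnamed variable is created first,
# so every model variable created after it is renumbered.
evaluation = ToyGraph()
extra = evaluation.add_variable()
eval_weights = evaluation.add_variable()
eval_biases = evaluation.add_variable()

# Names the evaluation graph asks for that the checkpoint never stored.
missing = [v for v in evaluation.variables if v not in checkpoint_keys]
print(missing)
```

In this sketch the evaluation graph's weights end up under the name the training graph used for its biases, and the last name has no checkpoint entry at all. Explicit names per variable, or leaving out the extra variable, avoid the shift.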

If you do want to add a new variable, you need to restore differently. The loading code should look something like this:

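import os
import tensorflow as tf

# Assumes checkpoint_dir, ckpt_name, and an open session sess are defined elsewhere.
ckpt_path = os.path.join(checkpoint_dir, ckpt_name)
reader = tf.train.NewCheckpointReader(ckpt_path)
restore_dict = dict()
for v in tf.trainable_variables():
    tensor_name = v.name.split(':')[0]  # strip the ':0' output suffix
    if reader.has_tensor(tensor_name):
        print('has tensor ', tensor_name)
        restore_dict[tensor_name] = v
    # Put the logic for any new/modified variable here and add it to restore_dict, e.g.
    # restore_dict['my_var_scope/my_var'] = get_my_variable()

# Restore only the variables that were found in the checkpoint.
saver = tf.train.Saver(restore_dict)
saver.restore(sess, ckpt_path)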
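
The filtering step in the loop above can be tried on plain strings. This is only a sketch of the name matching: split_restorable and all variable names in it are made up for illustration, and a real run would take the keys from the checkpoint reader instead of a hard-coded set.

```python
def split_restorable(graph_var_names, checkpoint_keys):
    """Partition graph variables by whether the checkpoint has an entry for them."""
    restorable = {}
    missing = []
    for full_name in graph_var_names:
        # Variable names carry an output suffix such as ':0';
        # checkpoint keys do not, so strip it before comparing.
        key = full_name.split(':')[0]
        if key in checkpoint_keys:
            restorable[key] = full_name
        else:
            missing.append(full_name)
    return restorable, missing


# Hypothetical names: two saved model variables plus one new variable.
checkpoint_keys = {'conv1/weights', 'conv1/biases'}
graph_var_names = ['conv1/weights:0', 'conv1/biases:0', 'Variable:0']

restorable, missing = split_restorable(graph_var_names, checkpoint_keys)
print(restorable)
print(missing)
```

Anything that lands in missing is not touched by the restore, so it still has to be initialized some other way before it is used.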

【Discussion】:
