【发布时间】:2018-04-11 12:03:55
【问题描述】:
Tensorflow Eager Execution 中基于名称的保存与基于对象的保存有什么区别?
【问题讨论】:
标签: tensorflow
Tensorflow Eager Execution 中基于名称的保存与基于对象的保存有什么区别?
【问题讨论】:
标签: tensorflow
TensorFlow 传统上使用全局变量名来匹配检查点值与图中的变量。基本上只是变量的.name 属性:
import tensorflow as tf
tf.enable_eager_execution()
dense = tf.keras.layers.Dense(2)
dense(tf.ones([1, 1]))
print(dense.variables[0].name)
打印:
dense/kernel:0
这是tf.train.Saver 写入检查点的名称,以及它用来匹配恢复值的键。当 Python 程序包含单个 TensorFlow 模型或模型构建被隔离时(如 tf.estimator.Estimator,它会构建它从头开始包装在新的 Graph 中的模型),它运行良好。
基于对象的检查点,tf.contrib.eager.Checkpoint / tfe.Checkpoint,旨在在 Python 程序更改或在同一 Python 程序中使用多个 TensorFlow 模型时使此变量匹配更加稳健。它通过构建具有命名边的对象的依赖图并将其与检查点一起保存来做到这一点:
(来自eager GAN example的可视化;黑色节点是Layer对象,蓝色是变量,红色是优化器,橙色是优化器创建的槽变量)
这些命名依赖是在将属性分配给Checkpointable 对象时自动创建的,包括tf.keras.Model。例如self.conv1 = layers.Conv2D(...) makes a dependency edge named "conv1" when self is a tf.keras.Model。
恢复时,模型的结构(对象及其命名的边)应该匹配,不一定是变量的确切名称。
回到Dense层,我们可以为它做一个检查点,然后将它恢复到第二个变量名不匹配的对象中:
import tensorflow.contrib.eager as tfe
save_checkpoint = tfe.Checkpoint(dense=dense)
dense.variables[0].assign([[1., 2.]])
save_path = save_checkpoint.save("/tmp/tensorflow/ckpt")
# save_path="/tmp/tensorflow/ckpt-1"
然后在恢复的时候,还是在同一个程序里:
second_dense = tf.keras.layers.Dense(2)
restore_checkpoint = tfe.Checkpoint(dense=second_dense)
restore_checkpoint.restore(save_path)
second_dense(tf.ones([1, 1]))
print(second_dense.variables[0])
打印:
<tf.Variable 'dense_1/kernel:0' shape=(1, 2) dtype=float32, numpy=array([[1., 2.]], dtype=float32)>
[[1., 2.]] 的值在被Dense 层使用(创建时恢复)之前已恢复,尽管名称不同(dense_1/kernel 而不是dense/kernel)。
虽然在急切执行时特别有用,但基于对象的保存在图形构建时也很有效。只需添加run_restore_ops():
import tensorflow as tf
import tensorflow.contrib.eager as tfe
dense = tf.keras.layers.Dense(2)
dense(tf.ones([1, 1]))
save_checkpoint = tfe.Checkpoint(dense=dense)
assign_op = tf.group(dense.variables[0].assign([[1., 2.]]),
dense.variables[1].assign([3., 4.]))
second_dense = tf.keras.layers.Dense(2)
restore_checkpoint = tfe.Checkpoint(dense=second_dense)
second_dense(tf.ones([1, 1]))
with tf.Session() as session:
session.run(assign_op)
save_path = save_checkpoint.save("/tmp/tensorflow/ckpt")
restore_checkpoint.restore(save_path).assert_consumed().run_restore_ops()
print(session.run(second_dense.variables[0]))
打印:
[[1. 2.]]
有用的资源:
tfe.Checkpoint 的文档:https://www.tensorflow.org/api_docs/python/tf/contrib/eager/Checkpoint
tfe.Checkpointable 类的文档:https://www.tensorflow.org/api_docs/python/tf/contrib/eager/Checkpointable
【讨论】: