5.4 TensorFlow模型持久化
5.4.1. ckpt文件保存方法
在对模型进行加载时候,需要定义出与原来的计算图结构完全相同的计算图,然后才能进行加载,并且不需要对定义出来的计算图进行初始化操作。
这样保存下来的模型,会在其文件夹下生成三个文件,分别是:
* .ckpt.meta文件,保存tensorflow模型的计算图结构。
* .ckpt文件,保存计算图下所有变量的取值。
* checkpoint文件,保存目录下所有模型文件列表。
import tensorflow as tf #保存计算两个变量和的模型 v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1)) v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1)) result = v1 + v2 init_op = tf.global_variables_initializer() saver = tf.train.Saver() with tf.Session() as sess: sess.run(init_op) saver.save(sess, "Saved_model/model.ckpt") #加载保存了两个变量和的模型 with tf.Session() as sess: saver.restore(sess, "Saved_model/model.ckpt") print sess.run(result) INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt [-1.6226364] #直接加载持久化的图。因为之前没有导出v3,所以这里会报错 saver = tf.train.import_meta_graph("Saved_model/model.ckpt.meta") v3 = tf.Variable(tf.random_normal([1], stddev=1, seed=1)) with tf.Session() as sess: saver.restore(sess, "Saved_model/model.ckpt") print sess.run(v1) print sess.run(v2) print sess.run(v3) INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt [-0.81131822] [-0.81131822] # 变量重命名,这样可以通过字典将模型保存时的变量名和需要加载的变量联系起来 v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "other-v1") v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "other-v2") saver = tf.train.Saver({"v1": v1, "v2": v2})