【问题标题】:tensorflow: saving and restoring sessiontensorflow:保存和恢复会话
【发布时间】:2016-04-02 16:31:13
【问题描述】:

我正在尝试实施答案中的建议: Tensorflow: how to save/restore a model?

我有一个对象以sklearn 样式包装tensorflow 模型。

import tensorflow as tf
class tflasso():
    saver = tf.train.Saver()
    def __init__(self,
                 learning_rate = 2e-2,
                 training_epochs = 5000,
                    display_step = 50,
                    BATCH_SIZE = 100,
                    ALPHA = 1e-5,
                    checkpoint_dir = "./",
             ):
        ...

    def _create_network(self):
       ...


    def _load_(self, sess, checkpoint_dir = None):
        if checkpoint_dir:
            self.checkpoint_dir = checkpoint_dir

        print("loading a session")
        ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            self.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            raise Exception("no checkpoint found")
        return

    def fit(self, train_X, train_Y , load = True):
        self.X = train_X
        self.xlen = train_X.shape[1]
        # n_samples = y.shape[0]

        self._create_network()
        tot_loss = self._create_loss()
        optimizer = tf.train.AdagradOptimizer( self.learning_rate).minimize(tot_loss)

        # Initializing the variables
        init = tf.initialize_all_variables()
        " training per se"
        getb = batchgen( self.BATCH_SIZE)

        yvar = train_Y.var()
        print(yvar)
        # Launch the graph
        NUM_CORES = 3  # Choose how many cores to use.
        sess_config = tf.ConfigProto(inter_op_parallelism_threads=NUM_CORES,
                                                           intra_op_parallelism_threads=NUM_CORES)
        with tf.Session(config= sess_config) as sess:
            sess.run(init)
            if load:
                self._load_(sess)
            # Fit all training data
            for epoch in range( self.training_epochs):
                for (_x_, _y_) in getb(train_X, train_Y):
                    _y_ = np.reshape(_y_, [-1, 1])
                    sess.run(optimizer, feed_dict={ self.vars.xx: _x_, self.vars.yy: _y_})
                # Display logs per epoch step
                if (1+epoch) % self.display_step == 0:
                    cost = sess.run(tot_loss,
                            feed_dict={ self.vars.xx: train_X,
                                    self.vars.yy: np.reshape(train_Y, [-1, 1])})
                    rsq =  1 - cost / yvar
                    logstr = "Epoch: {:4d}\tcost = {:.4f}\tR^2 = {:.4f}".format((epoch+1), cost, rsq)
                    print(logstr )
                    self.saver.save(sess, self.checkpoint_dir + 'model.ckpt',
                       global_step= 1+ epoch)

            print("Optimization Finished!")
        return self

当我跑步时:

tfl = tflasso()
tfl.fit( train_X, train_Y , load = False)

我得到输出:

Epoch:   50 cost = 38.4705  R^2 = -1.2036
    b1: 0.118122
Epoch:  100 cost = 26.4506  R^2 = -0.5151
    b1: 0.133597
Epoch:  150 cost = 22.4330  R^2 = -0.2850
    b1: 0.142261
Epoch:  200 cost = 20.0361  R^2 = -0.1477
    b1: 0.147998

但是,当我尝试恢复参数时(即使没有杀死对象): tfl.fit( train_X, train_Y , load = True)

我得到了奇怪的结果。首先,加载的值与保存的不对应。

loading a session
loaded b1: 0.1          <------- Loaded another value than saved
Epoch:   50 cost = 30.8483  R^2 = -0.7670
    b1: 0.137484  

什么是正确的加载方式,可能首先检查保存的变量?

【问题讨论】:

  • tensorflow 文档没有非常基本的示例,您必须在示例文件夹中挖掘并自行理解它

标签: python scikit-learn tensorflow


【解决方案1】:

TL;DR:您应该尝试重新编写这个类,以便 (i) 只调用一次 self.create_network(),并且 (ii) 在构造 tf.train.Saver() 之前调用。

这里有两个微妙的问题,一个是代码结构,另一个是tf.train.Saver constructor 的默认行为。当您构造一个不带参数的保护程序时(如在您的代码中),它会收集程序中的当前变量集,并将操作添加到图形中以保存和恢复它们。在你的代码中,当你调用tflasso()时,它会构造一个saver,不会有变量(因为create_network()还没有被调用)。因此,检查点应该是空的。

第二个问题是,默认情况下,保存的检查点的格式是从name property of a variable 到其当前值的映射。如果你创建了两个同名的变量,它们会被 TensorFlow 自动“唯一化”:

v = tf.Variable(..., name="weights")
assert v.name == "weights"
w = tf.Variable(..., name="weights")
assert v.name == "weights_1"  # The "_1" is added by TensorFlow.

这样做的结果是,当您在第二次调用 tfl.fit() 时调用 self.create_network() 时,变量的名称都将与检查点中存储的名称不同——或者如果保护程序有建网后。 (您可以通过将 name-Variable 字典传递给保护程序构造函数来避免这种行为,但这通常很尴尬。)

有两种主要的解决方法:

  1. 在每次调用 tflasso.fit() 时,重新创建整个模型,方法是定义一个新的 tf.Graph,然后在该图中构建网络并创建一个 tf.train.Saver

  2. 推荐创建网络,然后在tflasso 构造函数中创建tf.train.Saver,并在每次调用tflasso.fit() 时重用此图。请注意,您可能需要做更多的工作来重组事物(特别是,我不确定您对 self.Xself.xlen 做了什么)但应该可以通过 placeholders 和喂食来实现这一点。

【讨论】:

  • 谢谢! xlenself._create_network() 中用于设置X 的输入大小(占位符初始化:self.vars.xx = tf.placeholder("float", shape=[None, self.xlen]))。根据您的说法,首选方法是将xlen 传递给初始化程序。
  • 有没有办法在对象重新初始化时重置 uniquifier / 清除旧的 tf 变量?
  • 为此,您需要创建一个新的 tf.Graph 并将其设为默认值,然后再 (i) 创建网络并 (ii) 创建一个 Saver。如果您将tflasso.fit() 的主体包裹在with tf.Graph().as_default(): 块中,并将Saver 构造移动到该块内,则每次调用fit() 时名称应该相同。
猜你喜欢
  • 1970-01-01
  • 2018-02-07
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2016-07-31
  • 2016-08-16
相关资源
最近更新 更多