【问题标题】:Initializing variables with imported tensors from another graph使用从另一个图中导入的张量初始化变量
【发布时间】:2017-09-08 08:21:34
【问题描述】:

我在 python3 中使用 tensorflow(版本:v1.1.0-13-g8ddd727 1.1.0)(Python 3.4.3(默认,2016 年 11 月 17 日,01:08:31)[GCC 4.8.4] 在 linux 上) , 它是从源代码安装并基于 GPU 的。

我想知道是否可以使用从另一个会话中导入的张量来初始化变量,因为 tensorflow 文档没有提到它,我在 stackoverflow 上找到了它。

train_dir = './gan/train_logs'
    ckpt = tf.train.latest_checkpoint(train_dir)
    filename = ".".join([ckpt, 'meta'])
    print(filename)
    saver = tf.train.import_meta_graph(filename)
    saver.restore(sess, ckpt)
    test = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')

这里的张量已经成功导入,我想用它们来初始化同一个生成器。

感谢您的帮助!

【问题讨论】:

    标签: python-3.x tensorflow deep-learning


    【解决方案1】:

    您所要做的就是创建 tf.assign 操作。

    所以你这样做:

    old_weights = .... # your loading
    
    new_weights = tf.Variable( ... ) # any initialisation here!
    
    initialise_new_weights = tf.assign(new_weights, old_weights)
    
    with tf.train.MonitoredSession() as sess:
      # at this point new_weights are randomly initialised
      sess.run(initialise_new_weight) # now they are initialised to your values
    

    或者你可以直接传递初始化参数

    old_weights = .... # your loading
    
    new_weights = tf.Variable( ..., initializer = tf.constant_initialiser(old_weights) ) 
    
    with tf.train.MonitoredSession() as sess:
      # they are initialised to your values
    

    【讨论】:

    • 非常感谢您的帮助!它确实有效!问题肯定是这个 tf.assign 并且我没有意识到我必须将它作为一个操作来运行。
    猜你喜欢
    • 1970-01-01
    • 2016-05-01
    • 1970-01-01
    • 1970-01-01
    • 2017-12-10
    • 1970-01-01
    • 1970-01-01
    • 2018-12-29
    相关资源
    最近更新 更多