【问题标题】:How to initialize a tf.Variable with a tf.constant or a numpy array?如何使用 tf.constant 或 numpy 数组初始化 tf.Variable?
【发布时间】:2019-04-29 19:16:20
【问题描述】:

我正在尝试在 tf.InteractiveSession() 中初始化 tf.Variable()。我已经有一些预训练的权重,它们是单独的 numpy 文件。如何使用这些 numpy 值有效地初始化变量?

我已经经历了以下选择:

  1. 使用tf.assign()
  2. 在创建tf.Variable() 期间直接使用sess.run()

似乎值未正确初始化。 以下是我尝试过的一些代码。让我知道哪个是正确的?

def read_numpy(file):
    return np.fromfile(file,dtype='f')

def build_network():
    with tf.get_default_graph().as_default():
        x = tf.Variable(tf.constant(read_numpy('foo.npy')),name='var1')
        sess = tf.get_default_session()
        with sess.as_default():
            sess.run(tf.global_variables_initializer())

sess = tf.InteractiveSession()
with sess.as_default():
    build_network()

这是正确的方法吗?我已经打印了session 对象,它与自始至终使用的会话相同。

编辑:目前似乎使用sess.run(tf.global_variables_initializer()) 正在调用随机初始化操作

【问题讨论】:

    标签: python numpy tensorflow global-variables


    【解决方案1】:

    tf.Variable() 接受 numpy 数组作为初始值:

    import tensorflow as tf
    import numpy as np
    
    init = np.ones((2, 2))
    x = tf.Variable(init) # <-- set initial value to assign to a variable
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer()) # <-- this will assign the init value
        print(x.eval())
    # [[1. 1.]
    #  [1. 1.]]
    

    所以只需使用numpy数组进行初始化,无需先将其转换为张量。

    或者,您也可以使用 tf.Variable.load() 将 numpy 数组中的值分配给会话上下文中的变量:

    import tensorflow as tf
    import numpy as np
    
    x = tf.Variable(tf.zeros((2, 2)))
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        init = np.ones((2, 2))
        x.load(init)
        print(x.eval())
    # [[1. 1.]
    #  [1. 1.]]
    

    【讨论】:

    • 可以使用tf.get_variable() 吗? ..我用过它,它奏效了。我有一个属性为initializer=tf.constant(read_numpy('foo.npy'))
    • 是的,你可以。 tf.Variable()tf.get_variable() 之间的区别在于tf.Variable() 创建了一个新变量,而get_variable() 可以创建一个新变量,也可以返回一个现有变量。更多阅读this答案
    • 哦,好吧,我第一次使用tf.get_variable() 时确实遇到了一个错误,因为命名发生冲突.. 但后来我将它包含在tf.variable_scope() 中并且得到了处理。感谢您的回复!
    • 很高兴它有帮助。我仍然会使用x = tf.Variable(read_numpy('foo.npy'))。这样它会更短。
    猜你喜欢
    • 2015-08-05
    • 2014-06-28
    • 1970-01-01
    • 1970-01-01
    • 2017-07-27
    • 2018-12-04
    • 2011-05-30
    • 2019-07-03
    • 1970-01-01
    相关资源
    最近更新 更多