【问题标题】:TensorFlow: getting variable by nameTensorFlow:按名称获取变量
【发布时间】:2016-02-28 04:57:43
【问题描述】:

在使用 TensorFlow Python API 时,我创建了一个变量(未在构造函数中指定其 name),其 name 属性的值为 "Variable_23:0"。当我尝试使用tf.get_variable("Variable23") 选择此变量时,会创建一个名为"Variable_23_1:0" 的新变量。如何正确选择"Variable_23" 而不是创建一个新的?

我要做的是按名称选择变量,然后重新初始化它,以便微调权重。

【问题讨论】:

    标签: tensorflow


    【解决方案1】:

    get_variable() 函数创建一个新变量或返回一个由get_variable() 先前创建的变量。它不会返回使用tf.Variable() 创建的变量。这是一个简单的例子:

    >>> with tf.variable_scope("foo"):
    ...   bar1 = tf.get_variable("bar", (2,3)) # create
    ... 
    >>> with tf.variable_scope("foo", reuse=True):
    ...   bar2 = tf.get_variable("bar")  # reuse
    ... 
    
    >>> with tf.variable_scope("", reuse=True): # root variable scope
    ...   bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above)
    ... 
    >>> (bar1 is bar2) and (bar2 is bar3)
    True
    

    如果您没有使用tf.get_variable() 创建变量,您有几个选择。首先,您可以使用tf.global_variables()(正如@mrry 建议的那样):

    >>> bar1 = tf.Variable(0.0, name="bar")
    >>> bar2 = [var for var in tf.global_variables() if var.op.name=="bar"][0]
    >>> bar1 is bar2
    True
    

    或者你可以像这样使用tf.get_collection()

    >>> bar1 = tf.Variable(0.0, name="bar")
    >>> bar2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="bar")[0]
    >>> bar1 is bar2
    True
    

    编辑

    你也可以使用get_tensor_by_name():

    >>> bar1 = tf.Variable(0.0, name="bar")
    >>> graph = tf.get_default_graph()
    >>> bar2 = graph.get_tensor_by_name("bar:0")
    >>> bar1 is bar2
    False, bar2 is a Tensor througn convert_to_tensor on bar1. but bar1 equal 
    bar2 in value.
    

    回想一下,张量是操作的输出。它与操作同名,加上:0。如果操作有多个输出,则它们的名称与操作相同,加上:0:1:2 等。

    【讨论】:

    • 你能解释一下这里发生了什么吗?对我没用
    • @Matian2040 我改进了我的答案,希望现在应该清楚了。 :)
    【解决方案2】:

    按名称获取变量的最简单方法是在tf.global_variables() 集合中搜索它:

    var_23 = [v for v in tf.global_variables() if v.name == "Variable_23:0"][0]
    

    这适用于现有变量的临时重用。 Sharing Variables tutorial 中介绍了一种更结构化的方法(当您希望在模型的多个部分之间共享变量时)。

    【讨论】:

    • 你能提供一个共享变量的例子吗?它一直要求我重用,我明白这意味着什么,但我无法让 tensorflow 工作。
    • UPDATE:WARNING:tensorflow:From :1 in .: all_variables (from tensorflow.python.ops.variables) 已弃用,将在之后删除2017-03-02。更新说明:请改用 tf.global_variables。
    • 嗨,谢谢。我有不同的情况。你能告诉我我可以使用get_tensor_by_nameget variable by name 来得到tf.layers.dense 定义的东西,例如,means 在这个代码示例here 中。
    • 不迭代所有创建的操作并改用密钥是不是更好的方法?
    • 链接已失效。太旧了?
    【解决方案3】:

    如果您想从模型中获取任何存储的变量,请使用tf.train.load_variable("model_folder_name","Variable name")

    【讨论】:

      【解决方案4】:

      根据@mrry 的回答,我认为创建和使用以下函数会更好,因为还有局部变量,以及其他不在全局变量中的变量(它们在不同的集合中):

      def get_var_by_name(query_name, var_list):
          """
          Get Variable by name
      
          e.g.
          local_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES)
          the_var = get_var_by_name(local_vars, 'accuracy/total:0')
          """
          target_var = None
          for var in var_list:
              if var.name==query_name:
                  target_var = var
                  break
          return target_var
      

      【讨论】:

        猜你喜欢
        • 2016-08-05
        • 1970-01-01
        • 2014-02-08
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2012-02-11
        • 1970-01-01
        相关资源
        最近更新 更多