get_variable() 函数创建一个新变量或返回一个由get_variable() 先前创建的变量。它不会返回使用tf.Variable() 创建的变量。这是一个简单的例子:
>>> with tf.variable_scope("foo"):
... bar1 = tf.get_variable("bar", (2,3)) # create
...
>>> with tf.variable_scope("foo", reuse=True):
... bar2 = tf.get_variable("bar") # reuse
...
>>> with tf.variable_scope("", reuse=True): # root variable scope
... bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above)
...
>>> (bar1 is bar2) and (bar2 is bar3)
True
如果您没有使用tf.get_variable() 创建变量,您有几个选择。首先,您可以使用tf.global_variables()(正如@mrry 建议的那样):
>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = [var for var in tf.global_variables() if var.op.name=="bar"][0]
>>> bar1 is bar2
True
或者你可以像这样使用tf.get_collection():
>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="bar")[0]
>>> bar1 is bar2
True
编辑
你也可以使用get_tensor_by_name():
>>> bar1 = tf.Variable(0.0, name="bar")
>>> graph = tf.get_default_graph()
>>> bar2 = graph.get_tensor_by_name("bar:0")
>>> bar1 is bar2
False, bar2 is a Tensor througn convert_to_tensor on bar1. but bar1 equal
bar2 in value.
回想一下,张量是操作的输出。它与操作同名,加上:0。如果操作有多个输出,则它们的名称与操作相同,加上:0、:1、:2 等。