【发布时间】:2017-10-28 05:23:36
【问题描述】:
我发现 tf.scatter_add 的行为非常奇怪:我创建了一个 tf.while_loop,它创建了一个包装在 tf.Variable 中的张量。
如果我没有在循环外的变量中添加一些东西,tensorflow 会导致一个错误,告诉我该变量是不可变的。
这是一个 MWE:
import tensorflow as tf
m = 25
batch_num = 32
num_bus = 50
C = tf.zeros((m, batch_num, num_bus, m),tf.float64)
C = tf.Variable(C)
c = tf.ones((batch_num, num_bus, m), tf.float64)
#C = tf.scatter_add(C,0,c)
k = tf.constant(1)
stop_cond = lambda k,C: k<m
def construct_C(k, C):
upd_c = c+1
C = tf.scatter_add(C,k,upd_c)
return k+1,C
k,C = tf.while_loop(stop_cond,construct_C, (k,C))
sess = tf.Session()
sess.run(tf.global_variables_initializer())
C1 = sess.run(C)
此代码导致错误:TypeError: 'ScatterAdd' Op requires that input 'ref' be a mutable tensor (e.g.: a tf.Variable)。但是,当我取消注释 C = tf.scatter_add(C,0,c) 时,一切正常。
这是故意的吗?我做错了什么?
【问题讨论】:
-
我认为可能是 tf.while_loop 将 C 变成了不可变张量,但我不能太确定。您可以尝试使用常规的 python 循环。
标签: python tensorflow