【问题标题】:tf.scatter_add causes error in looptf.scatter_add 导致循环错误
【发布时间】:2017-10-28 05:23:36
【问题描述】:

我发现 tf.scatter_add 的行为非常奇怪:我创建了一个 tf.while_loop,它创建了一个包装在 tf.Variable 中的张量。

如果我没有在循环外的变量中添加一些东西,tensorflow 会导致一个错误,告诉我该变量是不可变的。

这是一个 MWE:

import tensorflow as tf        

m = 25
batch_num = 32
num_bus = 50

C = tf.zeros((m, batch_num, num_bus, m),tf.float64)
C = tf.Variable(C)

c = tf.ones((batch_num, num_bus, m), tf.float64)
#C = tf.scatter_add(C,0,c)

k = tf.constant(1)

stop_cond = lambda k,C: k<m

def construct_C(k, C):
    upd_c = c+1
    C = tf.scatter_add(C,k,upd_c)
    return k+1,C

k,C = tf.while_loop(stop_cond,construct_C, (k,C))

sess = tf.Session()
sess.run(tf.global_variables_initializer())
C1 = sess.run(C)

此代码导致错误:TypeError: 'ScatterAdd' Op requires that input 'ref' be a mutable tensor (e.g.: a tf.Variable)。但是,当我取消注释 C = tf.scatter_add(C,0,c) 时,一切正常。

这是故意的吗?我做错了什么?

【问题讨论】:

  • 我认为可能是 tf.while_loop 将 C 变成了不可变张量,但我不能太确定。您可以尝试使用常规的 python 循环。

标签: python tensorflow


【解决方案1】:

听起来有些 while_loop 原语不知道变量(相反,他们知道 ref 类型的张量)。这看起来像是代码中的错误 - 请在 github 上提交问题。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2020-08-01
    • 2015-12-10
    • 1970-01-01
    相关资源
    最近更新 更多