我相信您需要的是ticket #206 中讨论的assign_slice_update。 不过目前还没有。
更新:现已实施。见 jdehesa 的回答:https://stackoverflow.com/a/43139565/6531137
在assign_slice_update(或scatter_nd())可用之前,您可以构建所需行的块,其中包含您不想修改的值以及要更新的所需值,如下所示:
import tensorflow as tf
a = tf.Variable(tf.ones([10,36,36]))
i = 3
j = 5
# Gather values inside the a[i,...] block that are not on column j
idx_before = tf.concat(1, [tf.reshape(tf.tile(tf.Variable([i]), [j]), [-1, 1]), tf.reshape(tf.range(j), [-1, 1])])
values_before = tf.gather_nd(a, idx_before)
idx_after = tf.concat(1, [tf.reshape(tf.tile(tf.Variable([i]), [36-j-1]), [-1, 1]), tf.reshape(tf.range(j+1, 36), [-1, 1])])
values_after = tf.gather_nd(a, idx_after)
# Build a subset of tensor `a` with the values that should not be touched and the values to update
block = tf.concat(0, [values_before, 5*tf.ones([1, 36]), values_after])
d = tf.scatter_update(a, i, block)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
sess.run(d)
print(a.eval()[3,4:7,:]) # Print a subset of the tensor to verify
该示例生成一个张量并执行a[i,j,:] = 5。大部分复杂性在于获取我们不想修改的值,a[i,~j,:](否则scatter_update() 将替换这些值)。
如果您想按照您的要求执行T[i,k,:] = a[1,1,:],您需要将前面示例中的5*tf.ones([1, 36]) 替换为tf.gather_nd(a, [[1, 1]])。
另一种方法是为 tf.select() 从中创建所需元素的掩码并将其分配回变量,如下所示:
import tensorflow as tf
a = tf.Variable(tf.zeros([10,36,36]))
i = tf.Variable([3])
j = tf.Variable([5])
# Build a mask using indices to perform [i,j,:]
atleast_2d = lambda x: tf.reshape(x, [-1, 1])
indices = tf.concat(1, [atleast_2d(tf.tile(i, [36])), atleast_2d(tf.tile(j, [36])), atleast_2d(tf.range(36))])
mask = tf.cast(tf.sparse_to_dense(indices, [10, 36, 36], 1), tf.bool)
to_update = 5*tf.ones_like(a)
out = a.assign( tf.select(mask, to_update, a) )
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
sess.run(out)
print(a.eval()[2:5,5,:])
它在内存方面的效率可能较低,因为它需要两倍的内存来处理 a 类似 to_update 变量,但您可以轻松修改最后一个示例以从 @987654338 获得梯度保留操作@节点。您可能也有兴趣查看其他 StackOverflow 问题:Conditional assignment of tensor values in TensorFlow。
当合适的 TensorFlow 函数可用时,应将这些不雅的扭曲替换为调用正确的 TensorFlow 函数。