有多种方法可以做到这一点。这是一个tf.one_hot()(测试代码):
import tensorflow as tf
a = tf.constant( [[1,1,1,1], [2,2,2,2],[3,3,3,3], [4, 4, 4, 4]] )
b = tf.constant( [ 1, 3, 4 ] )
one_hot = tf.one_hot( b, a.get_shape()[ 0 ].value, dtype = a.dtype )
mask = 1 - tf.reduce_sum( one_hot, axis = 0 )
res = a * mask[ ..., None ]
with tf.Session() as sess:
print( sess.run( res ) )
或者这个tf.scatter_nd()(测试代码):
import tensorflow as tf
a = tf.constant( [[1,1,1,1], [2,2,2,2], [3,3,3,3], [4, 4, 4, 4]] )
b = tf.constant( [ 1, 3 ] )
mask = 1 - tf.scatter_nd( b[ ..., None ], tf.ones_like( b ), shape = [ a.get_shape()[ 0 ].value ] )
res = a * mask[ ..., None ]
with tf.Session() as sess:
print( sess.run( res ) )
都会输出:
[[1 1 1 1]
[0 0 0 0]
[3 3 3 3]
[0 0 0 0]]
根据需要。