* 也适用于 SparseTensor,您的问题似乎与 SparseTensor 本身有关,您可能提供了超出范围的索引你给它的形状,考虑这个例子:
A_t = tf.SparseTensor(indices=[[0,6],[4,4]], values=[3.2,5.1], dense_shape=(5,5))
请注意列索引6 大于指定的形状,该形状应具有最大5 列,这会产生与您显示的相同的错误:
b = np.array([1.0, 2.0, 0.0, 0.0, 1.0])
B_t = tf.Variable(b, dtype=tf.float32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(A_t * B_t))
InvalidArgumentError(参见上面的回溯):提供的索引是
w.r.t. 越界具有广播形状的密集面
这是一个工作示例:
A_t = tf.SparseTensor(indices=[[0,3],[4,4]], values=[3.2,5.1], dense_shape=(5,5))
b = np.array([1.0, 2.0, 0.0, 0.0, 1.0])
B_t = tf.Variable(b, dtype=tf.float32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(A_t * B_t))
# SparseTensorValue(indices=array([[0, 3],
# [4, 4]], dtype=int64), values=array([ 0. , 5.0999999], dtype=float32), dense_shape=array([5, 5], dtype=int64))