【问题标题】:no broadcasting for tf.matmul in tensorflow for 4D 3D tensors在 4D 3D 张量的 tensorflow 中没有广播 tf.matmul
【发布时间】:2017-09-12 19:52:39
【问题描述】:

首先我在这里找到另一个问题No broadcasting for tf.matmul in TensorFlow
但是这个问题并不能解决我的问题。

我的问题是一批矩阵乘以另一批向量。

x=tf.placeholder(tf.float32,shape=[10,1000,3,4])
y=tf.placeholder(tf.float32,shape=[1000,4])

x是一组矩阵。共有10*1000个矩阵。每个矩阵的形状为[3,4]
y 是一组向量。有 1000 个向量。每个向量的形状为[4]
x 的 1 和 y 的 0 相同。(这里是 1000)
如果 tf.matmul 支持广播,我可以写

y=tf.reshape(y,[1,1000,4,1])
result=tf.matmul(x,y)
result=tf.reshape(result,[10,1000,3])

但是 tf.matmul 不支持广播
如果我使用上面提到的问题的方法

x=tf.reshape(x,[10*1000*3,4])
y=tf.transpose(y,perm=[1,0]) #[4,1000]
result=tf.matmul(x,y)
result=tf.reshape(result,[10,1000,3,1000])

结果的形状是 [10,1000,3,1000],而不是 [10,1000,3]。
不知道怎么去掉多余的1000
如何获得与支持广播的 tf.matmul 相同的结果?

【问题讨论】:

    标签: tensorflow


    【解决方案1】:

    我自己解决。

    x=tf.transpose(x,perm=[1,0,2,3]) #[1000,10,3,4]
    x=tf.reshape(x,[1000,30,4])
    y=tf.reshape(y,[1000,4,1])
    result=tf.matmul(x,y) #[1000,30,1]
    result=tf.reshape(result,[1000,10,3])
    result=tf.transpose(result,perm=[1,0,2]) #[10,1000,3]
    

    【讨论】:

    • 很遗憾,没有更好的办法,即使使用tf.tensordot
    【解决方案2】:

    here 所示,您可以使用函数来解决:

    def broadcast_matmul(A, B):
      "Compute A @ B, broadcasting over the first `N-2` ranks"
      with tf.variable_scope("broadcast_matmul"):
        return tf.reduce_sum(A[..., tf.newaxis] * B[..., tf.newaxis, :, :],
                             axis=-2)
    

    【讨论】:

      猜你喜欢
      • 2016-10-29
      • 1970-01-01
      • 2017-11-01
      • 1970-01-01
      • 2017-05-21
      • 2020-06-12
      • 1970-01-01
      • 2021-07-19
      • 2023-03-19
      相关资源
      最近更新 更多