【问题标题】:matrix multiplication broadcast in nd4jnd4j中的矩阵乘法广播
【发布时间】:2019-05-06 15:44:47
【问题描述】:

在python中,假设

a = np.array(range(0,12)).reshape(2,2,3)
b = np.array(range(0,6)).reshape(3,2)
c = np.matmul(a,b) // a @ b

我们有

a: array([[[ 0,  1,  2],
        [ 3,  4,  5]],

       [[ 6,  7,  8],
        [ 9, 10, 11]]])

b: array([[0, 1],
       [2, 3],
       [4, 5]])

c: array([[[10, 13],
        [28, 40]],

       [[46, 67],
        [64, 94]]])

有人可以帮助我在没有 for 循环的情况下在 java nd4j 中实现等效操作吗?我试过broadcast.mul,但事实证明broadcast.mul 是逐元素乘法。没有找到mmul的广播操作。

【问题讨论】:

    标签: nd4j


    【解决方案1】:

    我自己想出来的。答案如下所示,以防有人需要。 使用Nd4j.tensorMmul,可以轻松实现矩阵广播。例如

    val a = Nd4j.create(0d to 11d by 1d toArray, Array[Int](2, 2, 3))
    val b = Nd4j.create(0d to 5d by 1d toArray, Array[Int](3, 2))
    Nd4j.tensorMmul(a, b, Array(Array(2), Array(0))) // matrix broadcast
    

    这是 scala 的代码。对于java,你只需要修改代码来创建数组。

    【讨论】:

      猜你喜欢
      • 2015-01-07
      • 1970-01-01
      • 1970-01-01
      • 2020-07-21
      • 1970-01-01
      • 2016-06-27
      • 2014-08-06
      • 2017-07-07
      • 1970-01-01
      相关资源
      最近更新 更多