【问题标题】:Fastest way to multiply arrays of matrices in Python (numpy)在 Python 中乘以矩阵数组的最快方法(numpy)
【发布时间】:2011-11-12 10:58:00
【问题描述】:

我有两个 2×2 复数矩阵数组,我想知道将它们相乘的最快方法是什么。 (我想对矩阵数组的元素做矩阵乘法。)目前我有

numpy.array(map(lambda i: numpy.dot(m1[i], m2[i]), range(l)))

但是还有比这更好的吗?

谢谢,

v923z

【问题讨论】:

    标签: arrays performance matrix numpy multiplication


    【解决方案1】:

    numpy.einsum 是这个问题的最佳解决方案,它在 DaveP 参考文献的底部被提及。代码很干净,很容易理解,并且比循环遍历数组并逐个进行乘法运算快一个数量级。下面是一些示例代码:

    import numpy
    l = 100
    
    m1 = rand(l,2,2)
    m2 = rand(l,2,2)
    
    m3 = numpy.array(map(lambda i: numpy.dot(m1[i], m2[i]), range(l)))
    m3e = numpy.einsum('lij,ljk->lik', m1, m2)
    
    %timeit numpy.array(map(lambda i: numpy.dot(m1[i], m2[i]), range(l)))
    %timeit numpy.einsum('lij,ljk->lik', m1, m2)
    
    print np.all(m3==m3e)
    

    以下是在 ipython 笔记本中运行时的返回值:
    1000 次循环,最好的 3 次:每个循环 479 µs
    10000 次循环,最好的 3 次:每个循环 48.9 µs
    是的

    【讨论】:

    【解决方案2】:

    我认为您正在寻找的答案是here。不幸的是,这是一个涉及重塑的相当混乱的解决方案。

    【讨论】:

      【解决方案3】:

      如果m1m2 是2x2 复矩阵的一维数组,那么它们本质上具有(l,2,2) 的形状。所以最后两个轴上的矩阵乘法相当于将m1 的最后一个轴与m2 的倒数第二个轴的乘积相加。这正是np.dot 所做的:

      np.dot(m1,m2)
      

      或者,既然你有复数矩阵,也许你想先取m1 的复共轭。在这种情况下,请使用np.vdot

      PS。如果m1 是一个列表 2x2 复杂矩阵,那么也许看看你是否可以重新排列你的代码,使m1 从一开始就成为一个形状为(l,2,2) 的数组。

      如果不可能,列表推导

      [np.dot(m1[i],m2[i]) for i in range(l)]
      

      将比使用maplambda 更快,但执行l np.dots 会比在两个形状为(l,2,2) 的数组上执行np.dot 慢,如上所述。

      【讨论】:

        【解决方案4】:

        如果 m1 和 m2 是 2x2 复矩阵的一维数组,那么它们本质上具有 (l,2,2) 形状。所以最后两个轴上的矩阵乘法相当于将 m1 的最后一个轴与 m2 的倒数第二个轴的乘积相加。这正是 np.dot 所做的:

        但这不是 np.dot 所做的。

         a = numpy.array([numpy.diag([1, 2]), numpy.diag([2, 3]), numpy.diag([3, 4])])
        

        生成一个由 2×2 矩阵组成的 (3,2,2) 数组。但是,numpy.dot(a,a) 创建了 6 个矩阵,结果的形状是 (3, 2, 3, 2)。那不是我需要的。我需要的是一个包含 numpy.dot(a[0],a[0]), numpy.dot(a[1],a[1]), numpy.dot(a[2],a[2] ) ...

        [np.dot(m1[i],m2[i]) for i in range(l)]
        

        应该可以,但我还没有检查,它是否比 lambda 表达式的映射更快。

        干杯,

        v923z

        编辑:for 循环和地图以大致相同的速度运行。转换为 numpy.array 会消耗大量时间,但这两种方法都必须完成,因此这里没有任何收获。

        【讨论】:

          【解决方案5】:

          可能是这个问题太老了,但我仍在寻找答案。

          我试过这段代码

          a=np.asarray(range(1048576),dtype='complex');b=np.reshape(a//1024,(1024,1024));b=b+1J*b
          %timeit c=np.dot(b,b)
          %timeit d=np.einsum('ij, ki -> jk', b,b).T
          

          结果是:对于'点'

          10 loops, best of 3: 174 ms per loop
          

          对于'einsum'

          1 loops, best of 3: 4.51 s per loop
          

          我已经检查过 c 和 d 是否相同

          (c==d).all()
          True
          

          仍然是“点”是赢家,我仍在寻找更好的方法但没有成功

          【讨论】:

            猜你喜欢
            • 1970-01-01
            • 1970-01-01
            • 1970-01-01
            • 2015-10-12
            • 1970-01-01
            • 1970-01-01
            • 1970-01-01
            • 1970-01-01
            • 1970-01-01
            相关资源
            最近更新 更多