【发布时间】:2011-11-12 10:58:00
【问题描述】:
我有两个 2×2 复数矩阵数组,我想知道将它们相乘的最快方法是什么。 (我想对矩阵数组的元素做矩阵乘法。)目前我有
numpy.array(map(lambda i: numpy.dot(m1[i], m2[i]), range(l)))
但是还有比这更好的吗?
谢谢,
v923z
【问题讨论】:
标签: arrays performance matrix numpy multiplication
我有两个 2×2 复数矩阵数组,我想知道将它们相乘的最快方法是什么。 (我想对矩阵数组的元素做矩阵乘法。)目前我有
numpy.array(map(lambda i: numpy.dot(m1[i], m2[i]), range(l)))
但是还有比这更好的吗?
谢谢,
v923z
【问题讨论】:
标签: arrays performance matrix numpy multiplication
numpy.einsum 是这个问题的最佳解决方案,它在 DaveP 参考文献的底部被提及。代码很干净,很容易理解,并且比循环遍历数组并逐个进行乘法运算快一个数量级。下面是一些示例代码:
import numpy
l = 100
m1 = rand(l,2,2)
m2 = rand(l,2,2)
m3 = numpy.array(map(lambda i: numpy.dot(m1[i], m2[i]), range(l)))
m3e = numpy.einsum('lij,ljk->lik', m1, m2)
%timeit numpy.array(map(lambda i: numpy.dot(m1[i], m2[i]), range(l)))
%timeit numpy.einsum('lij,ljk->lik', m1, m2)
print np.all(m3==m3e)
以下是在 ipython 笔记本中运行时的返回值:
1000 次循环,最好的 3 次:每个循环 479 µs
10000 次循环,最好的 3 次:每个循环 48.9 µs
是的
【讨论】:
我认为您正在寻找的答案是here。不幸的是,这是一个涉及重塑的相当混乱的解决方案。
【讨论】:
如果m1 和m2 是2x2 复矩阵的一维数组,那么它们本质上具有(l,2,2) 的形状。所以最后两个轴上的矩阵乘法相当于将m1 的最后一个轴与m2 的倒数第二个轴的乘积相加。这正是np.dot 所做的:
np.dot(m1,m2)
或者,既然你有复数矩阵,也许你想先取m1 的复共轭。在这种情况下,请使用np.vdot。
PS。如果m1 是一个列表 2x2 复杂矩阵,那么也许看看你是否可以重新排列你的代码,使m1 从一开始就成为一个形状为(l,2,2) 的数组。
如果不可能,列表推导
[np.dot(m1[i],m2[i]) for i in range(l)]
将比使用map 和lambda 更快,但执行l np.dots 会比在两个形状为(l,2,2) 的数组上执行np.dot 慢,如上所述。
【讨论】:
如果 m1 和 m2 是 2x2 复矩阵的一维数组,那么它们本质上具有 (l,2,2) 形状。所以最后两个轴上的矩阵乘法相当于将 m1 的最后一个轴与 m2 的倒数第二个轴的乘积相加。这正是 np.dot 所做的:
但这不是 np.dot 所做的。
a = numpy.array([numpy.diag([1, 2]), numpy.diag([2, 3]), numpy.diag([3, 4])])
生成一个由 2×2 矩阵组成的 (3,2,2) 数组。但是,numpy.dot(a,a) 创建了 6 个矩阵,结果的形状是 (3, 2, 3, 2)。那不是我需要的。我需要的是一个包含 numpy.dot(a[0],a[0]), numpy.dot(a[1],a[1]), numpy.dot(a[2],a[2] ) ...
[np.dot(m1[i],m2[i]) for i in range(l)]
应该可以,但我还没有检查,它是否比 lambda 表达式的映射更快。
干杯,
v923z
编辑:for 循环和地图以大致相同的速度运行。转换为 numpy.array 会消耗大量时间,但这两种方法都必须完成,因此这里没有任何收获。
【讨论】:
可能是这个问题太老了,但我仍在寻找答案。
我试过这段代码
a=np.asarray(range(1048576),dtype='complex');b=np.reshape(a//1024,(1024,1024));b=b+1J*b
%timeit c=np.dot(b,b)
%timeit d=np.einsum('ij, ki -> jk', b,b).T
结果是:对于'点'
10 loops, best of 3: 174 ms per loop
对于'einsum'
1 loops, best of 3: 4.51 s per loop
我已经检查过 c 和 d 是否相同
(c==d).all()
True
仍然是“点”是赢家,我仍在寻找更好的方法但没有成功
【讨论】: