【发布时间】:2009-12-01 19:13:31
【问题描述】:
两个n维向量u=[u1,u2,...un]和v=[v1,v2,...,vn]的点积由u1*v1 + u2*v2 + ... + un*vn给出。
posted yesterday 的一个问题鼓励我找到在 Python 中仅使用标准库、不使用第三方模块或 C/Fortran/C++ 调用来计算点积的最快方法。
我计时了四种不同的方法;到目前为止,最快的似乎是sum(starmap(mul,izip(v1,v2)))(其中starmap 和izip 来自itertools 模块)。
对于下面显示的代码,这些是经过的时间(以秒为单位,一百万次运行):
d0: 12.01215
d1: 11.76151
d2: 12.54092
d3: 09.58523
你能想出一种更快的方法吗?
import timeit # module with timing subroutines
import random # module to generate random numnbers
from itertools import imap,starmap,izip
from operator import mul
def v(N=50,min=-10,max=10):
"""Generates a random vector (in an array) of dimension N; the
values are integers in the range [min,max]."""
out = []
for k in range(N):
out.append(random.randint(min,max))
return out
def check(v1,v2):
if len(v1)!=len(v2):
raise ValueError,"the lenght of both arrays must be the same"
pass
def d0(v1,v2):
"""
d0 is Nominal approach:
multiply/add in a loop
"""
check(v1,v2)
out = 0
for k in range(len(v1)):
out += v1[k] * v2[k]
return out
def d1(v1,v2):
"""
d1 uses an imap (from itertools)
"""
check(v1,v2)
return sum(imap(mul,v1,v2))
def d2(v1,v2):
"""
d2 uses a conventional map
"""
check(v1,v2)
return sum(map(mul,v1,v2))
def d3(v1,v2):
"""
d3 uses a starmap (itertools) to apply the mul operator on an izipped (v1,v2)
"""
check(v1,v2)
return sum(starmap(mul,izip(v1,v2)))
# generate the test vectors
v1 = v()
v2 = v()
if __name__ == '__main__':
# Generate two test vectors of dimension N
t0 = timeit.Timer("d0(v1,v2)","from dot_product import d0,v1,v2")
t1 = timeit.Timer("d1(v1,v2)","from dot_product import d1,v1,v2")
t2 = timeit.Timer("d2(v1,v2)","from dot_product import d2,v1,v2")
t3 = timeit.Timer("d3(v1,v2)","from dot_product import d3,v1,v2")
print "d0 elapsed: ", t0.timeit()
print "d1 elapsed: ", t1.timeit()
print "d2 elapsed: ", t2.timeit()
print "d3 elapsed: ", t3.timeit()
请注意,文件名必须为dot_product.py,脚本才能运行;我在 Mac OS X 版本 10.5.8 上使用了 Python 2.5.1。
编辑:
我运行了 N=1000 的脚本,结果如下(以秒为单位,一百万次运行):
d0: 205.35457
d1: 208.13006
d2: 230.07463
d3: 155.29670
我想可以肯定地假设,实际上,选项三是最快的,而选项二是最慢的(在所提出的四个中)。
【问题讨论】:
-
@Arrieta:您可以通过将 'from dot_product' 替换为 'from main' 来删除文件名为 dot_product.py 的要求。
-
@unutbu:当然,我只是认为使用该名称保存文件以便快速运行比更改脚本更简单。谢谢。
-
我的结果是:d0 经过:13.4328830242 d1 经过:9.52215504646 d2 经过:10.1050257683 d3 经过:9.16764998436 请务必检查 d1 和 d3 之间的差异是否具有统计学意义。
-
@liori:没错。我正在运行 N=1000 的问题,预计会看到更大的差异。
-
如果你重复做一个点积,保持其中一个向量不变,动态编译方法可能值得研究。固定部分为 0 的所有项都可以全部去掉,固定部分为 1 的乘法可以去掉。