【发布时间】:2014-04-10 16:13:32
【问题描述】:
我有大数据集,代表 220 维周期空间中的 120 万个点(x 变化 fom (-pi,pi))...(矩阵:1.2M x 220)。
我想在考虑周期性的情况下计算这些点之间的距离直方图。我已经用 python 编写了一些代码,但是对于我的测试用例来说它仍然运行得很慢(我什至没有尝试在整个集合上运行它......)。
你可以看看并帮我做一些调整吗?
非常感谢任何建议和cmets。
import numpy as np
# 1000x220 test set (-pi,pi)
d=np.random.random((1000, 220))*2*np.pi-np.pi
# calculating theoretical limit on the histogram range, max distance between
# two points can be pi in each dimension
m=np.zeros(np.shape(d)[1])+np.pi
m_=np.sqrt(np.sum(m**2))
# hist range is from 0 to mm
mm=np.floor(m_)
bins=mm/0.01
m=np.zeros(bins)
# proper calculations
import time
start_time = time.time()
for i in range(np.shape(d)[0]):
diff=d[:-(i+1),:]-d[i+1:,:]
diff=np.absolute(diff)
adiff=diff-np.pi
diff=np.pi-np.absolute(adiff)
s=np.sqrt(np.einsum('ij,ij->i', diff,diff))
m+=np.histogram(s,range=(0,mm),bins=bins)[0]
print time.time() - start_time
【问题讨论】:
-
您是否对代码进行了概要分析以查看其大部分时间都花在了哪些地方?
-
在这种情况下,您计算 diff[abs(diff)>np.pi] 两次。尝试做一次,即将它保存在一个变量中并使用它。
-
np.power似乎消耗了相当长的时间。在快速测试中,使用diff * diff的速度明显更快。更快一点仍然是将np.sum(diff*diff,1)替换为np.einsum('ij,ij->i', diff, diff)。 -
另外,我会将 einsum 行及其上方的行 (diff2) 替换为
np.sqrt(np.einsum('ij,ij->i', diff, diff))。 -
并且只使用
m += np.histogram(...)[0]也会节省一点时间(不会先将数组复制到h,只需添加到位)
标签: python numpy scipy distance