【发布时间】:2020-06-01 11:22:48
【问题描述】:
下面的代码运行太慢了。我尝试使用numpy.argwhere 而不是“if 语句”来加速代码,我得到了一个非常有效的结果,但它仍然很慢。我也试过numpy.frompyfunc 和numpy.vectorize 但我失败了。你有什么建议可以加快下面的代码?
import numpy as np
import time
time1 = time.time()
n = 1000000
k = 10000
velos = np.linspace(-1000, 1000, n)
line_centers = np.linspace(-1000, 1000, k)
weights = np.random.random_sample(k)
rvs = np.arange(-60, 60, 2)
m = len(rvs)
w = np.arange(10)
M = np.zeros((n, m))
for l, lc in enumerate(line_centers):
vi = velos - lc
for j in range(m - 1):
w = np.argwhere((vi < rvs[j + 1]) & (vi > rvs[j])).T[0]
M[w, j] = weights[l] * (rvs[j + 1] - vi[w]) / (rvs[j + 1] - rvs[j])
M[w, j + 1] = weights[l] * (vi[w] - rvs[j]) / (rvs[j + 1] - rvs[j])
time2 = time.time()
print(time2 - time1)
编辑:
数组M 的大小不正确。我修好了它。
【问题讨论】:
-
你可能想看看numba:numba.pydata.org
-
我来看看。谢谢。
-
如果你想使用 Numba,把所有东西写在一个简单的嵌套循环中(没有像 argwhere 这样的矢量化命令)
-
请注意,从您的代码中删除
argwhere(例如,将其替换为常量w=1,例如,将循环速度提高 20 倍。因此,如果您可以重写它,您的代码可能已经快很多了。嵌套循环在这里可能无关紧要。
标签: python python-3.x numpy vectorization