我的跟踪使用 3 个过滤器,rot90、np.where、np.sum 和 np.multiply。我不确定哪种基准测试方法更合理。如果您不考虑创建过滤器的时间,则速度大约快 4 倍。
# Each filter basically does what `op` tries to achieve in a loop
filter1 = np.array([[0, 1 ,0, 0, 0],
[1, -4, 1, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]) /5.
filter2 = np.array([[0, 0 ,1, 0, 0],
[0, 1, -4, 1, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]) /5.
filter3 = np.array([[0, 0 ,0, 0, 0],
[0, 0, 1, 0, 0],
[0, 1, -4, 1, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 0, 0]]) /5.
# only loop over the center of the matrix, a
center = np.array([[0, 0, 0, 0, 0],
[0, 1, 1, 1, 0],
[0, 1, 1, 1, 0],
[0, 1, 1, 1, 0],
[0, 0, 0, 0, 0]])
filter1 和 filter2 可以旋转以分别表示 4 个过滤器。
filter1_90_rot = np.rot90(filter1, k=1)
filter1_180_rot = np.rot90(filter1, k=2)
filter1_270_rot = np.rot90(filter1, k=3)
filter2_90_rot = np.rot90(filter2, k=1)
filter2_180_rot = np.rot90(filter2, k=2)
filter2_270_rot = np.rot90(filter2, k=3)
# Based on different index from `a` return different filter
filter_dict = {
(1,1): filter1,
(3,1): filter1_90_rot,
(3,3): filter1_180_rot,
(1,3): filter1_270_rot,
(1,2): filter2,
(2,1): filter2_90_rot,
(3,2): filter2_180_rot,
(2,3): filter2_270_rot,
(2,2): filter3
}
主要功能
def get_new_a(a):
x, y = np.where(((a > 3) * center) > 0) # find pairs that match the condition
return a + np.sum(np.multiply(filter_dict[i, j], a[i,j])
for (i, j) in zip(x,y))
注意:似乎存在一些数字错误,例如 np.equal() 在我的结果和 OP 之间大多会返回 False,而 np.close() 会返回 true。
计时结果
def op():
temp_a = np.copy(a)
for i in range(1,a.shape[0]-1):
for j in range(1,a.shape[1]-1):
if a[i,j] > 3:
temp_a[i+1,j] += a[i,j] / 5.
temp_a[i-1,j] += a[i,j] / 5.
temp_a[i,j+1] += a[i,j] / 5.
temp_a[i,j-1] += a[i,j] / 5.
temp_a[i,j] -= a[i,j] * 4. / 5.
a2 = np.copy(temp_a)
%timeit op()
167 µs ± 2.72 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit get_new_a(a):
37.2 µs ± 2.68 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
再次注意,我们忽略了创建过滤器的时间,因为我认为这将是一次性的事情。 如果您确实想包括创建过滤器的时间,它大约快两倍。您可能认为这不公平,因为 op 的方法包含两个 np.copy。我认为 op 方法的瓶颈是 for 循环。
参考:
numpy.multiply 在两个矩阵之间进行元素乘法。
np.rot90 为我们进行旋转。 k 是一个参数,您可以决定旋转多少次。
np.isclose 可以使用这个函数来检查两个矩阵是否在你可以定义的某个误差范围内接近。