说“Matlab 总是比 NumPy 快”是错误的,反之亦然
反之亦然。他们的表现通常是可比的。使用 NumPy 时,要变得更好
性能你必须牢记 NumPy 的速度来自于调用
用 C/C++/Fortran 编写的底层函数。申请时效果很好
这些函数到整个数组。一般来说,在 Python 循环中对较小的数组或标量调用这些 NumPy 函数时,性能会较差。
你问的 Python 循环有什么问题?通过 Python 循环的每次迭代都是
调用next 方法。每次使用[] 索引都是调用
__getitem__ 方法。每个+= 都是对__iadd__ 的调用。每个虚线属性
查找(例如np.dot)涉及函数调用。那些函数调用
加起来是速度的重大障碍。这些钩子给了 Python
表现力——字符串的索引意味着与索引不同的东西
例如,对于 dicts。相同的语法,不同的含义。魔术是通过为对象提供不同的__getitem__ 方法来实现的。
但这种表现力是以速度为代价的。所以当你不需要所有
那种动态的表现力,为了得到更好的表现,尽量把自己限制在
NumPy 函数调用整个数组。
所以,删除 for 循环;尽可能使用“矢量化”方程。例如,而不是
for i in range(m):
delta3 = -(x[i,:]-a3[i,:])*a3[i,:]* (1 - a3[i,:])
您可以同时为每个i 计算delta3:
delta3 = -(x-a3)*a3*(1-a3)
而在for-loop delta3 是一个向量,使用向量化方程delta3 是一个矩阵。
for-loop 中的一些计算不依赖于i,因此应该在循环之外提升。例如,sum2 看起来像一个常量:
sum2 = sparse.beta*(-float(sparse.rho)/rhoest + float(1.0 - sparse.rho) / (1.0 - rhoest) )
这是一个可运行的示例,其中包含您的代码 (orig) 的替代实现 (alt)。
我的 timeit 基准测试显示 速度提高了 6.8 倍:
In [52]: %timeit orig()
1 loops, best of 3: 495 ms per loop
In [53]: %timeit alt()
10 loops, best of 3: 72.6 ms per loop
import numpy as np
class Bunch(object):
""" http://code.activestate.com/recipes/52308 """
def __init__(self, **kwds):
self.__dict__.update(kwds)
m, n, p = 10 ** 4, 64, 25
sparse = Bunch(
theta1=np.random.random((p, n)),
theta2=np.random.random((n, p)),
b1=np.random.random((p, 1)),
b2=np.random.random((n, 1)),
)
x = np.random.random((m, n))
a3 = np.random.random((m, n))
a2 = np.random.random((m, p))
a1 = np.random.random((m, n))
sum2 = np.random.random((p, ))
sum2 = sum2[:, np.newaxis]
def orig():
partial_j1 = np.zeros(sparse.theta1.shape)
partial_j2 = np.zeros(sparse.theta2.shape)
partial_b1 = np.zeros(sparse.b1.shape)
partial_b2 = np.zeros(sparse.b2.shape)
delta3t = (-(x - a3) * a3 * (1 - a3)).T
for i in range(m):
delta3 = delta3t[:, i:(i + 1)]
sum1 = np.dot(sparse.theta2.T, delta3)
delta2 = (sum1 + sum2) * a2[i:(i + 1), :].T * (1 - a2[i:(i + 1), :].T)
partial_j1 += np.dot(delta2, a1[i:(i + 1), :])
partial_j2 += np.dot(delta3, a2[i:(i + 1), :])
partial_b1 += delta2
partial_b2 += delta3
# delta3: (64, 1)
# sum1: (25, 1)
# delta2: (25, 1)
# a1[i:(i+1),:]: (1, 64)
# partial_j1: (25, 64)
# partial_j2: (64, 25)
# partial_b1: (25, 1)
# partial_b2: (64, 1)
# a2[i:(i+1),:]: (1, 25)
return partial_j1, partial_j2, partial_b1, partial_b2
def alt():
delta3 = (-(x - a3) * a3 * (1 - a3)).T
sum1 = np.dot(sparse.theta2.T, delta3)
delta2 = (sum1 + sum2) * a2.T * (1 - a2.T)
# delta3: (64, 10000)
# sum1: (25, 10000)
# delta2: (25, 10000)
# a1: (10000, 64)
# a2: (10000, 25)
partial_j1 = np.dot(delta2, a1)
partial_j2 = np.dot(delta3, a2)
partial_b1 = delta2.sum(axis=1)
partial_b2 = delta3.sum(axis=1)
return partial_j1, partial_j2, partial_b1, partial_b2
answer = orig()
result = alt()
for a, r in zip(answer, result):
try:
assert np.allclose(np.squeeze(a), r)
except AssertionError:
print(a.shape)
print(r.shape)
raise
提示:请注意,我在 cmets 中留下了所有中间数组的形状。了解数组的形状有助于我理解您的代码在做什么。数组的形状可以帮助引导您使用正确的 NumPy 函数。或者至少,注意形状可以帮助您了解操作是否合理。例如,当您计算时
np.dot(A, B)
和A.shape = (n, m) 和B.shape = (m, p),那么np.dot(A, B) 将是一个形状为(n, p) 的数组。
它可以帮助以 C_CONTIGUOUS 顺序构建数组(至少,如果使用 np.dot)。这样做可能会提高 3 倍的速度:
下面,x 与 xf 相同,除了 x 是 C_CONTIGUOUS 和
xf 是 F_CONTIGUOUS——y 和 yf 的关系相同。
import numpy as np
m, n, p = 10 ** 4, 64, 25
x = np.random.random((n, m))
xf = np.asarray(x, order='F')
y = np.random.random((m, n))
yf = np.asarray(y, order='F')
assert np.allclose(x, xf)
assert np.allclose(y, yf)
assert np.allclose(np.dot(x, y), np.dot(xf, y))
assert np.allclose(np.dot(x, y), np.dot(xf, yf))
%timeit 基准测试显示速度差异:
In [50]: %timeit np.dot(x, y)
100 loops, best of 3: 12.9 ms per loop
In [51]: %timeit np.dot(xf, y)
10 loops, best of 3: 27.7 ms per loop
In [56]: %timeit np.dot(x, yf)
10 loops, best of 3: 21.8 ms per loop
In [53]: %timeit np.dot(xf, yf)
10 loops, best of 3: 33.3 ms per loop
关于 Python 中的基准测试:
It can be misleading 使用 time.time() 调用对的差异来对 Python 中的代码速度进行基准测试。
您需要多次重复测量。最好禁用自动垃圾收集器。测量大的时间跨度(例如至少 10 秒的重复次数)也很重要,以避免由于时钟计时器分辨率差而导致的错误,并减少 time.time 调用开销的重要性。 Python 不是自己编写所有代码,而是为您提供timeit module。我基本上是用它来计时代码片段,只是为了方便起见,我通过IPython terminal 调用它。
我不确定这是否会影响您的基准测试,但请注意它可能会有所作为。在question I linked to 中,根据time.time,两段代码相差1.7 倍,而使用timeit 的基准测试显示这两条代码运行的时间基本相同。