【发布时间】:2011-06-25 03:48:18
【问题描述】:
在应用程序的核心(用 Python 编写并使用 NumPy)我需要旋转一个四阶张量。实际上,我需要多次旋转很多张量,这是我的瓶颈。我的涉及八个嵌套循环的幼稚实现(如下)似乎很慢,但我看不到利用 NumPy 的矩阵运算并希望加快速度的方法。我觉得我应该使用np.tensordot,但我不知道该怎么做。
在数学上,旋转张量 T' 的元素由下式给出: T'ijkl = Σ gia gjb gkc gld Tabcd 总和超过右侧的重复索引。 T 和 Tprime 是 3*3*3*3 NumPy 数组,旋转矩阵 g 是 3*3 NumPy 数组。我的慢速实现(每次调用约 0.04 秒)如下。
#!/usr/bin/env python
import numpy as np
def rotT(T, g):
Tprime = np.zeros((3,3,3,3))
for i in range(3):
for j in range(3):
for k in range(3):
for l in range(3):
for ii in range(3):
for jj in range(3):
for kk in range(3):
for ll in range(3):
gg = g[ii,i]*g[jj,j]*g[kk,k]*g[ll,l]
Tprime[i,j,k,l] = Tprime[i,j,k,l] + \
gg*T[ii,jj,kk,ll]
return Tprime
if __name__ == "__main__":
T = np.array([[[[ 4.66533067e+01, 5.84985000e-02, -5.37671310e-01],
[ 5.84985000e-02, 1.56722231e+01, 2.32831900e-02],
[ -5.37671310e-01, 2.32831900e-02, 1.33399259e+01]],
[[ 4.60051700e-02, 1.54658176e+01, 2.19568200e-02],
[ 1.54658176e+01, -5.18223500e-02, -1.52814920e-01],
[ 2.19568200e-02, -1.52814920e-01, -2.43874100e-02]],
[[ -5.35577630e-01, 1.95558600e-02, 1.31108757e+01],
[ 1.95558600e-02, -1.51342210e-01, -6.67615000e-03],
[ 1.31108757e+01, -6.67615000e-03, 6.90486240e-01]]],
[[[ 4.60051700e-02, 1.54658176e+01, 2.19568200e-02],
[ 1.54658176e+01, -5.18223500e-02, -1.52814920e-01],
[ 2.19568200e-02, -1.52814920e-01, -2.43874100e-02]],
[[ 1.57414726e+01, -3.86167500e-02, -1.55971950e-01],
[ -3.86167500e-02, 4.65601977e+01, -3.57741000e-02],
[ -1.55971950e-01, -3.57741000e-02, 1.34215636e+01]],
[[ 2.58256300e-02, -1.49072770e-01, -7.38843000e-03],
[ -1.49072770e-01, -3.63410500e-02, 1.32039847e+01],
[ -7.38843000e-03, 1.32039847e+01, 1.38172700e-02]]],
[[[ -5.35577630e-01, 1.95558600e-02, 1.31108757e+01],
[ 1.95558600e-02, -1.51342210e-01, -6.67615000e-03],
[ 1.31108757e+01, -6.67615000e-03, 6.90486240e-01]],
[[ 2.58256300e-02, -1.49072770e-01, -7.38843000e-03],
[ -1.49072770e-01, -3.63410500e-02, 1.32039847e+01],
[ -7.38843000e-03, 1.32039847e+01, 1.38172700e-02]],
[[ 1.33639532e+01, -1.26331100e-02, 6.84650400e-01],
[ -1.26331100e-02, 1.34222177e+01, 1.67851800e-02],
[ 6.84650400e-01, 1.67851800e-02, 4.89151396e+01]]]])
g = np.array([[ 0.79389393, 0.54184237, 0.27593346],
[-0.59925749, 0.62028664, 0.50609776],
[ 0.10306737, -0.56714313, 0.8171449 ]])
for i in range(100):
Tprime = rotT(T,g)
有没有办法让它更快?让代码泛化到其他张量等级会很有用,但不太重要。
【问题讨论】:
-
而且,如果很明显在 numpy 或 scipy 中更快地做到这一点并不容易,我将推出一个 Fortran 扩展模块,看看它是如何执行的。
-
如果一切都失败了,你可以使用 Cython。据说是plays well with numpy。
-
虽然我相当确定有一种方法可以减少 numpy 中的嵌套循环(不过我没有立即看到),但正如@delnan 所说,您当前的代码是主要候选者对于 Cython....
标签: python optimization numpy rotation scipy