【发布时间】:2022-01-06 10:28:25
【问题描述】:
我正在尝试使用 Numba 来加速某些功能,特别是在给定三个角度的情况下执行 3D 旋转的功能,如下所示:
import numpy as np
from numba import jit
@jit(nopython=True)
def rotation_matrix(theta_x, theta_y, theta_z):
# Convert to radians. To ensure counter-clockwise (ccw) rotations, take
# negative of angles.
theta_x_rad = -np.radians(theta_x)
theta_y_rad = -np.radians(theta_y)
theta_z_rad = -np.radians(theta_z)
# Define rotation matrices (yaw, pitch, roll)
Rx = np.array([[1, 0,0],
[0, np.cos(theta_x_rad),-np.sin(theta_x_rad)],
[0, np.sin(theta_x_rad),np.cos(theta_x_rad) ]
])
Ry = np.array([[ np.cos(theta_y_rad), 0,np.sin(theta_y_rad)],
[ 0,1,0],
[-np.sin(theta_y_rad), 0,np.cos(theta_y_rad)]
])
Rz = np.array([[np.cos(theta_z_rad),-np.sin(theta_z_rad),0],
[np.sin(theta_z_rad),np.cos(theta_z_rad),0],
[0,0,1]
])
# Compute total rotation matrix
R = np.dot(Rz, np.dot( Ry, Rx ))
#
return R
函数比较简单,但是在Numba调用的时候,我尝试定义Rx的时候报错。看来Numba多维数组有问题(?)。我不确定如何修改它以便 Numba 可以使用它。任何帮助将不胜感激。
【问题讨论】:
-
我从来没有使用过 numba,(只有 Cython 可以加快速度),所以这可能有点不对劲,但你为什么要在这里首先使用它?你只是在做 numpy 操作,这些操作通常已经尽可能快了。如果您必须在循环中多次调用此函数,则最好创建一个形状为 (N, 3, 3) 的唯一 numpy 数组,其中所有 3x3 矩阵堆叠并在这些张量上使用
np.dot? -
这个函数被我脚本中的其他函数调用,但是当它被调用时,Numba 会抛出错误。问题不在于速度,而在于为什么 Numba 不喜欢这个功能。
-
尝试在
Rx, Ry和Rz的定义中将0替换为0.0,并将1替换为1.0。 -
好吧,抱歉,我真的帮不上忙。同样,我的经验仅限于 Cython,在那里我会避免使用任何非 python 对象(比如这里的 numpy),因为它们将不受支持或无论如何都不容易使用。您当然可以通过手动编写操作来重构您的代码(np.radians 只是
* pi/180,cos 和 sin 在math包中,3x3 矩阵的点积更长但易于编写),但它肯定会降低代码的可读性 -
或者您可以将
dtype=np.float32或dtype=np.float64传递给np.array,分别定义Rx, Ry或Rz。目前支持的类型列表,可以参考here.。
标签: python numpy multidimensional-array numba