【问题标题】:Numba Rotation Matrix with Multidimensional Arrays具有多维数组的 Numba 旋转矩阵
【发布时间】:2022-01-06 10:28:25
【问题描述】:

我正在尝试使用 Numba 来加速某些功能,特别是在给定三个角度的情况下执行 3D 旋转的功能,如下所示:

import numpy as np
from numba import jit

@jit(nopython=True)
def rotation_matrix(theta_x, theta_y, theta_z):
    # Convert to radians. To ensure counter-clockwise (ccw) rotations, take
    # negative of angles.
    theta_x_rad = -np.radians(theta_x)
    theta_y_rad = -np.radians(theta_y)
    theta_z_rad = -np.radians(theta_z)
    # Define rotation matrices (yaw, pitch, roll)
    Rx = np.array([[1, 0,0],
                    [0, np.cos(theta_x_rad),-np.sin(theta_x_rad)],
                    [0, np.sin(theta_x_rad),np.cos(theta_x_rad) ]
                    ])


    Ry = np.array([[ np.cos(theta_y_rad), 0,np.sin(theta_y_rad)],
                    [ 0,1,0],
                    [-np.sin(theta_y_rad), 0,np.cos(theta_y_rad)]
                    ])


    Rz = np.array([[np.cos(theta_z_rad),-np.sin(theta_z_rad),0],
                    [np.sin(theta_z_rad),np.cos(theta_z_rad),0],
                    [0,0,1]
                    ])


    # Compute total rotation matrix
    R  = np.dot(Rz, np.dot( Ry, Rx ))
    #
    return R

函数比较简单,但是在Numba调用的时候,我尝试定义Rx的时候报错。看来Numba多维数组有问题(?)。我不确定如何修改它以便 Numba 可以使用它。任何帮助将不胜感激。

【问题讨论】:

  • 我从来没有使用过 numba,(只有 Cython 可以加快速度),所以这可能有点不对劲,但你为什么要在这里首先使用它?你只是在做 numpy 操作,这些操作通常已经尽可能快了。如果您必须在循环中多次调用此函数,则最好创建一个形状为 (N, 3, 3) 的唯一 numpy 数组,其中所有 3x3 矩阵堆叠并在这些张量上使用 np.dot ?
  • 这个函数被我脚本中的其他函数调用,但是当它被调用时,Numba 会抛出错误。问题不在于速度,而在于为什么 Numba 不喜欢这个功能。
  • 尝试在Rx, RyRz 的定义中将0 替换为0.0,并将1 替换为1.0
  • 好吧,抱歉,我真的帮不上忙。同样,我的经验仅限于 Cython,在那里我会避免使用任何非 python 对象(比如这里的 numpy),因为它们将不受支持或无论如何都不容易使用。您当然可以通过手动编写操作来重构您的代码(np.radians 只是 * pi/180,cos 和 sin 在 math 包中,3x3 矩阵的点积更长但易于编写),但它肯定会降低代码的可读性
  • 或者您可以将dtype=np.float32dtype=np.float64 传递给np.array,分别定义Rx, RyRz。目前支持的类型列表,可以参考here.

标签: python numpy multidimensional-array numba


【解决方案1】:

问题来自整数和浮点类型值之间的混合。 Numba 尝试定义数组的类型,发现 [1, 0, 0] 是一个整数列表,但整个数组同时使用整数列表和浮点数列表进行初始化。由于整体类型不明确,类型推断被混淆并引发错误。你可以写 1.00.0 而不是 10 来解决这个问题。更一般地说,指定数组的dtype 通常是一种很好的做法,尤其是在Numba 中,因为类型推断

如果你想在函数第一次调用时避免在运行时出现编译错误,那么你可以精确的参数类型。请注意,您可以使用njit 代替nopython=True(更短)。生成的装饰器应该是@njit('(float64, float64, float64)')

【讨论】:

  • 从 int 到 float 的转换工作。谢谢你的解释。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2014-12-16
  • 2016-07-18
  • 1970-01-01
相关资源
最近更新 更多