【问题标题】:Broadcast rotation matrices multiplication广播旋转矩阵乘法
【发布时间】:2017-12-19 17:17:09
【问题描述】:

标有# <----的那一行怎么做更直接?

程序中x的每一行是一个点的坐标,rot_mat[0]rot_mat[1]是两个旋转矩阵。程序按每个旋转矩阵旋转x

改变每个旋转矩阵和坐标之间的乘法顺序很好,如果它使事情变得更简单的话。我想要x 的每一行或代表一个点坐标的结果。

结果应与检查相符。

程序:

# Rotation of coordinates of 4 points by 
# each of the 2 rotation matrices.
import numpy as np
from scipy.stats import special_ortho_group
rot_mats = special_ortho_group.rvs(dim=3, size=2)  # 2 x 3 x 3
x = np.arange(12).reshape(4, 3)
result = np.dot(rot_mats, x.T).transpose((0, 2, 1))  # <----
print("---- result ----")
print(result)
print("---- check ----")
print(np.dot(x, rot_mats[0].T))
print(np.dot(x, rot_mats[1].T))

结果:

---- result ----
[[[  0.20382264   1.15744672   1.90230739]
  [ -2.68064533   3.71537598   5.38610452]
  [ -5.56511329   6.27330525   8.86990165]
  [ -8.44958126   8.83123451  12.35369878]]

 [[  1.86544623   0.53905202  -1.10884323]
  [  5.59236544  -1.62845022  -4.00918928]
  [  9.31928465  -3.79595246  -6.90953533]
  [ 13.04620386  -5.9634547   -9.80988139]]]
---- check ----
[[  0.20382264   1.15744672   1.90230739]
 [ -2.68064533   3.71537598   5.38610452]
 [ -5.56511329   6.27330525   8.86990165]
 [ -8.44958126   8.83123451  12.35369878]]
[[  1.86544623   0.53905202  -1.10884323]
 [  5.59236544  -1.62845022  -4.00918928]
 [  9.31928465  -3.79595246  -6.90953533]
 [ 13.04620386  -5.9634547   -9.80988139]]

【问题讨论】:

  • 现在有什么问题?
  • 希望我能摆脱transpose((0, 2, 1))
  • 转置不是瓶颈,但np.dot 是。

标签: python numpy


【解决方案1】:

使用np.tensordot 进行涉及tensors 的乘法 -

np.tensordot(rot_mats, x, axes=((2),(1))).swapaxes(1,2)

这里有一些时间可以说服自己为什么 tensordottensors 更有效 -

In [163]: rot_mats = np.random.rand(20,30,30)
     ...: x = np.random.rand(40,30)

# With numpy.dot
In [164]: %timeit np.dot(rot_mats, x.T).transpose((0, 2, 1))
1000 loops, best of 3: 670 µs per loop

# With numpy.tensordot
In [165]: %timeit np.tensordot(rot_mats, x, axes=((2),(1))).swapaxes(1,2)
10000 loops, best of 3: 75.7 µs per loop

In [166]: rot_mats = np.random.rand(200,300,300)
     ...: x = np.random.rand(400,300)

# With numpy.dot
In [167]: %timeit np.dot(rot_mats, x.T).transpose((0, 2, 1))
1 loop, best of 3: 1.82 s per loop

# With numpy.tensordot
In [168]: %timeit np.tensordot(rot_mats, x, axes=((2),(1))).swapaxes(1,2)
10 loops, best of 3: 185 ms per loop

【讨论】:

  • swapaxes 是否提供新数组而不是视图?
  • @Rzu 给出一个看法。
  • @Rzu 对于 NumPy >= 1.10.0,如果 a 是 ndarray,则返回 a 的视图;否则创建一个新数组。对于早期的 NumPy 版本,仅当轴的顺序发生更改时才返回 a 的视图,否则返回输入数组。 - source
  • 不要认为它可以是一个视图,因为轴的顺序从 0、1、2 更改为 0、2、1。简单的测试是在交换轴之前和之后解开所有内容,并且看看他们是否给出相同的结果。
  • @Rzu 也许您正在考虑其他一些观点。我在谈论它是否是副本。更多信息 - scipy-cookbook.readthedocs.io/items/ViewsVsCopies.html
猜你喜欢
  • 2015-01-07
  • 2019-05-06
  • 1970-01-01
  • 1970-01-01
  • 2020-07-21
  • 1970-01-01
  • 2013-03-26
  • 1970-01-01
  • 2018-04-11
相关资源
最近更新 更多