【Question title】: Reshaping order in PyTorch - Fortran-like index ordering
【Posted】: 2021-01-05 15:56:43
【Question description】:

NumPy's reshape takes an order argument. It defaults to C, but you can specify another order such as F:

import numpy as np

a = np.arange(6).reshape((3, 2))
f = np.reshape(a, (2, 3), order='F')  # Fortran-like index ordering
c = np.reshape(a, (2, 3))             # default C-like index ordering
print('a= \n', a)
print('f= \n', f)
print('c= \n', c)

Result:

a= 
 [[0 1]
 [2 3]
 [4 5]]
f= 
 [[0 4 3]
 [2 1 5]]
c= 
 [[0 1 2]
 [3 4 5]]

Neither torch.reshape nor Tensor.view has an option for reshaping in F order. Is there a way to do an F-order reshape in PyTorch? I need everything to stay in PyTorch.
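
To make the requested behaviour concrete, here is a minimal NumPy-only sketch of what order='F' means: elements are read from the source in column-major order and written to the result in column-major order, so both arrays share the same column-major flattening. The variable names follow the example above; the two flags are illustrative names, not part of any API.

```python
import numpy as np

a = np.arange(6).reshape((3, 2))
f = np.reshape(a, (2, 3), order='F')

# Column-major (Fortran) flattening walks down each column first.
a_flat_f = a.ravel(order='F')
f_flat_f = f.ravel(order='F')

# An F-order reshape preserves the column-major flattening ...
same_f_order = np.array_equal(a_flat_f, f_flat_f)

# ... while the default C-order reshape preserves the row-major flattening.
c = np.reshape(a, (2, 3))
same_c_order = np.array_equal(a.ravel(), c.ravel())

print(a_flat_f)                    # [0 2 4 1 3 5]
print(same_f_order, same_c_order)  # True True
```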

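【Question discussion】:

  • Reshape the array in NumPy and then convert it to a PyTorch tensor?
  • Thanks, but I need everything in PyTorch because I want to use this function as a custom loss function in my DNN.

Tags: numpy view pytorch reshape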


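【Solution 1】:

I don't think PyTorch has built-in support for this. That said, you can get the desired result with Tensor.permute. Unfortunately, I doubt this will be very efficient: permute itself only returns a view, but reshaping the permuted, non-contiguous tensor generally forces a copy of the data.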

def reshape_fortran(x, shape):
    if len(x.shape) > 0:
        x = x.permute(*reversed(range(len(x.shape))))
    return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))

Example usage:

a = torch.arange(6).reshape(3, 2)
f = reshape_fortran(a, (2, 3))
c = a.reshape(2, 3)

which results in

a = 
tensor([[0, 1],
        [2, 3],
        [4, 5]])
f =
tensor([[0, 4, 3],
        [2, 1, 5]])
c =
tensor([[0, 1, 2],
        [3, 4, 5]])
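
One way to gain confidence in reshape_fortran is to compare it with NumPy's order='F' reshape on a few shapes, and to confirm that gradients flow through it, which matters for the custom-loss use case in the question. The sketch below repeats the function so that it runs on its own. The helper name check, the test shapes, and the weights tensor are illustrative assumptions, not part of the answer.

```python
import numpy as np
import torch

def reshape_fortran(x, shape):
    # Same function as in the answer, repeated so the block is self-contained.
    if len(x.shape) > 0:
        x = x.permute(*reversed(range(len(x.shape))))
    return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))

def check(src_shape, dst_shape):
    # Hypothetical helper: compare against NumPy's order='F' reshape.
    n = int(np.prod(src_shape))
    a = np.arange(n).reshape(src_shape)
    expected = np.reshape(a, dst_shape, order='F')
    actual = reshape_fortran(torch.from_numpy(a), dst_shape).numpy()
    return np.array_equal(expected, actual)

results = [
    check((3, 2), (2, 3)),
    check((4, 5, 6), (6, 20)),
    check((2, 3, 4), (4, 3, 2)),
    check((6,), (2, 3)),
]
print(results)

# permute and reshape are both differentiable, so gradients reach the input.
x = torch.arange(6, dtype=torch.float32).reshape(3, 2).requires_grad_()
weights = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
loss = (reshape_fortran(x, (2, 3)) * weights).sum()
loss.backward()
print(x.grad)
```

Each entry of x.grad is the weight that its element was multiplied by after the reordering, so the gradient is simply the weights tensor mapped back through the same F-order layout.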

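【Discussion】:

  • Thanks jodag, this is a very efficient answer. I had implemented it by transposing, flattening, and rearranging in succession, which worked but was not efficient. Thank you for your reply.

【Solution 2】:

jodag's answer works very well, and I ran some further performance tests. Compared with the built-in reshape, the permute-based reshape takes roughly 10 times as long. But NumPy's built-in Fortran-order reshape is also roughly 10 times slower than its C-order reshape, so the overhead of this method is comparable to NumPy's :)

Here are the test code and results on an i9-10900X and an RTX2080Ti: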

import numpy as np
import torch
import time

dim1 = 40
dim2 = 50
dim3 = 5

def reshape_fortran(x, shape):
    if len(x.shape) > 0:
        x = x.permute(*reversed(range(len(x.shape))))
    return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))

torch.cuda.set_device(0)
device = torch.device('cuda')

x = [torch.from_numpy(np.random.rand(dim1, dim2)).to(device) for _ in range(100)]
xx = [torch.from_numpy(np.random.rand(dim1, dim2)).to(device) for _ in range(100)]
for i in range(100):
    y = x[i].reshape([dim2, dim1])
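# C-order reshape (a view on a contiguous tensor, so no data is copied)
t0 = time.time()
for i in range(100):
    y = xx[i].reshape([dim2, dim3, -1])
t1 = time.time()

# Fortran-order reshape via permute; CUDA work is asynchronous and there is no
# torch.cuda.synchronize() here, so this mostly measures host-side time
for i in range(100):
    yy = reshape_fortran(xx[i], [dim2, dim3, -1])
t2 = time.time()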

print(f'torch build-in reshape: {(t1 - t0)/100} s')
print(f'torch permute reshape: {(t2 - t1)/100} s')

x = [np.random.rand(dim1, dim2) for _ in range(100)]
xx = [np.random.rand(dim1, dim2) for _ in range(100)]
for i in range(100):
    y = x[i].reshape([dim2, dim3, -1])
t0 = time.time()
for i in range(100):
    yy = xx[i].reshape([dim2, dim3, -1])
t1 = time.time()
for i in range(100):
    yyy = xx[i].reshape([dim2, dim3, -1], order='F')
t2 = time.time()

print(f'numpy C reshape: {(t1 - t0)/100} s')
print(f'numpy F reshape: {(t2 - t1)/100} s')

Result:

torch build-in reshape: 9.72747802734375e-07 s
torch permute reshape: 1.1897087097167968e-05 s
numpy C reshape: 3.0517578125e-07 s
numpy F reshape: 2.474784851074219e-06 s
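
For a more controlled comparison, one option is torch.utils.benchmark.Timer, which performs a warm-up run and synchronizes CUDA around the timed statement. The following is only a sketch under assumptions: it uses a single random 40x50 tensor rather than the 100 tensors above, falls back to the CPU when no GPU is available, and its numbers will differ from those reported here.

```python
import torch
from torch.utils import benchmark

def reshape_fortran(x, shape):
    # Same function as in the answers above, repeated for self-containment.
    if len(x.shape) > 0:
        x = x.permute(*reversed(range(len(x.shape))))
    return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))

# Use the GPU when present; otherwise the same comparison runs on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = torch.rand(40, 50, device=device)
shape = [50, 5, -1]

timings = {}
for label, stmt in [
    ('C reshape', 'x.reshape(shape)'),
    ('F reshape', 'reshape_fortran(x, shape)'),
]:
    timer = benchmark.Timer(
        stmt=stmt,
        globals={'x': x, 'shape': shape, 'reshape_fortran': reshape_fortran},
    )
    # timeit returns a Measurement; .mean is the mean time per run in seconds.
    timings[label] = timer.timeit(100).mean

for label, seconds in timings.items():
    print(f'{label}: {seconds:.3e} s')
```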
