jodag 的回答工作得非常好,我对性能进行了进一步的测试。与内置 reshape 相比,基于 permute 的重塑需要大约 10 倍的计算时间。但是在 numpy 中,内置的 Fortran 顺序 reshape(即 order='F')同样需要约 10 倍的计算时间。所以这个方法效率很高 :)
这是 i9-10900X 和 RTX2080Ti 的测试代码和结果:
import numpy as np
import torch
import time
# Benchmark dimensions: each input tensor/array is (dim1, dim2);
# the timed reshape target is (dim2, dim3, -1).
dim1 = 40
dim2 = 50
dim3 = 5
def reshape_fortran(x, shape):
    """Reshape tensor ``x`` to ``shape`` in Fortran (column-major) order.

    Torch equivalent of ``numpy.reshape(..., order='F')``: reverse all axes
    with a permute, reshape to the reversed target shape, then reverse the
    axes again so the result carries the requested shape.
    """
    ndim = x.dim()
    if ndim > 0:
        # Reversing every axis turns the C-order view into Fortran order.
        x = x.permute(*range(ndim - 1, -1, -1))
    out_rank = len(shape)
    reversed_shape = list(shape)[::-1]
    return x.reshape(*reversed_shape).permute(*range(out_rank - 1, -1, -1))
# --- GPU benchmark: built-in (C-order) reshape vs. permute-based Fortran reshape.
torch.cuda.set_device(0)
device = torch.device('cuda')
x = [torch.from_numpy(np.random.rand(dim1, dim2)).to(device) for _ in range(100)]
xx = [torch.from_numpy(np.random.rand(dim1, dim2)).to(device) for _ in range(100)]
# Warm-up pass so one-time CUDA context/launch overhead is not timed below.
for i in range(100):
    y = x[i].reshape([dim2, dim1])
# c reshape
# CUDA kernels launch asynchronously: without an explicit synchronize the
# wall clock measures launch time, not execution time.
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
    y = xx[i].reshape([dim2, dim3, -1])
torch.cuda.synchronize()
t1 = time.time()
# fortran reshape
for i in range(100):
    yy = reshape_fortran(xx[i], [dim2, dim3, -1])
torch.cuda.synchronize()
t2 = time.time()
print(f'torch build-in reshape: {(t1 - t0)/100} s')
print(f'torch permute reshape: {(t2 - t1)/100} s')
# --- CPU benchmark: numpy C-order reshape vs. Fortran-order reshape.
x = [np.random.rand(dim1, dim2) for _ in range(100)]
xx = [np.random.rand(dim1, dim2) for _ in range(100)]
# Warm-up pass (results discarded) before the timed loops.
for i in range(100):
    y = x[i].reshape([dim2, dim3, -1])
t0 = time.time()
for i in range(100):
    yy = xx[i].reshape([dim2, dim3, -1])
t1 = time.time()
# order='F' reads/writes elements in column-major (Fortran) order.
for i in range(100):
    yyy = xx[i].reshape([dim2, dim3, -1], order='F')
t2 = time.time()
print(f'numpy C reshape: {(t1 - t0)/100} s')
print(f'numpy F reshape: {(t2 - t1)/100} s')
torch build-in reshape: 9.72747802734375e-07 s
torch permute reshape: 1.1897087097167968e-05 s
numpy C reshape: 3.0517578125e-07 s
numpy F reshape: 2.474784851074219e-06 s