【发布时间】:2022-01-12 07:04:49
【问题描述】:
这是标准的批量矩阵乘法:
import torch
a = torch.arange(12, dtype=torch.float).view(2,3,2)
b = torch.arange(12, dtype=torch.float).view(2,3,2) - 1
c = a.matmul(b.transpose(-1,-2))
a,b,c
>>
(tensor([[[ 0., 1.],
[ 2., 3.],
[ 4., 5.]],
[[ 6., 7.],
[ 8., 9.],
[10., 11.]]]),
tensor([[[-1., 0.],
[ 1., 2.],
[ 3., 4.]],
[[ 5., 6.],
[ 7., 8.],
[ 9., 10.]]]),
tensor([[[ 0., 2., 4.],
[ -2., 8., 18.],
[ -4., 14., 32.]],
[[ 72., 98., 124.],
[ 94., 128., 162.],
[116., 158., 200.]]]))
这是我拥有的:
e = a.view(6,2)
f = b.view(6,2)
g = e.matmul(f.transpose(-1,-2))
e,f,g
>>
(tensor([[ 0., 1.],
[ 2., 3.],
[ 4., 5.],
[ 6., 7.],
[ 8., 9.],
[10., 11.]]),
tensor([[-1., 0.],
[ 1., 2.],
[ 3., 4.],
[ 5., 6.],
[ 7., 8.],
[ 9., 10.]]),
tensor([[ 0., 2., 4., 6., 8., 10.],
[ -2., 8., 18., 28., 38., 48.],
[ -4., 14., 32., 50., 68., 86.],
[ -6., 20., 46., 72., 98., 124.],
[ -8., 26., 60., 94., 128., 162.],
[-10., 32., 74., 116., 158., 200.]]))
很明显g 覆盖了c。我想知道是否有一种有效的方法可以从g 检索/切片c。请注意,这种检索/切片方法应该很好地推广到a 和b 的任何形状.
【问题讨论】:
-
我觉得我的回答更好。