【发布时间】:2021-05-10 04:13:41
【问题描述】:
我从以下 Pytorch 代码中收到此错误:
RuntimeError:梯度计算所需的变量之一已被就地操作修改:[torch.DoubleTensor [3]] 版本为 10;预计版本 9。
正如所见,代码没有就地操作。
import torch
device = torch.device('cpu')
class MesNet(torch.nn.Module):
def __init__(self):
super(MesNet, self).__init__()
self.cov_lin = torch.nn.Sequential(torch.nn.Linear(6, 5)).double()
def forward(self, u):
z_cov = self.cov_lin(u.transpose(0, 2).squeeze(-1))
return z_cov
class UpdateModel(torch.nn.Module):
def __init__(self):
torch.nn.Module.__init__(self)
self.P_dim = 18
self.Id3 = torch.eye(3).double()
def run_KF(self):
N = 10
u = torch.randn(N, 6).double()
v = torch.zeros(N, 3).double()
model = MesNet()
measurements_covs_l = model(u.t().unsqueeze(0))
# remember to remove this afterwards
torch.autograd.set_detect_anomaly(True)
for i in range(1, N):
v[i] = self.update_pos(v[i].detach(), measurements_covs_l[i-1])
criterion = torch.nn.MSELoss(reduction="sum")
targ = torch.rand(10, 3).double()
loss = criterion(v, targ)
loss = torch.mean(loss)
loss.backward()
return v, p
def update_pos(self, v, measurement_cov):
Omega = torch.eye(3).double()
H = torch.ones((5, self.P_dim)).double()
R = torch.diag(measurement_cov)
Kt = H.t().mm(torch.inverse(R))
# it is indicating inplace error even with this:
# Kt = H.t().mm(R)
dx = Kt.mv(torch.ones(5).double())
dR = self.trans(dx[:9].clone())
v_up = dR.mv(v)
return v_up
def trans(self, xi):
phi = xi[:3].clone()
angle = torch.norm(phi.clone())
if angle.abs().lt(1e-10):
skew_phi = torch.eye(3).double()
J = self.Id3 + 0.5 * skew_phi
Rot = self.Id3 + skew_phi
else:
axis = phi / angle
skew_axis = torch.eye(3).double()
s = torch.sin(angle)
c = torch.cos(angle)
Rot = c * self.Id3
return Rot
net = UpdateModel()
net.run_KF()
【问题讨论】:
标签: pytorch operators backpropagation