【Question title】: How to remove the in-place operation error in PyTorch?
【Posted】: 2021-05-10 04:13:41
【Question description】:

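I get this error from the PyTorch code below:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.DoubleTensor [3]] is at version 10; expected version 9 instead.

As far as I can see, the code contains no in-place operations.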

import torch
device = torch.device('cpu')
class MesNet(torch.nn.Module):
    def __init__(self):
        super(MesNet, self).__init__()
        self.cov_lin = torch.nn.Sequential(torch.nn.Linear(6, 5)).double()
    def forward(self, u):
        z_cov = self.cov_lin(u.transpose(0, 2).squeeze(-1))
        return z_cov
class UpdateModel(torch.nn.Module):
    def __init__(self):
        torch.nn.Module.__init__(self)
        self.P_dim = 18
        self.Id3 = torch.eye(3).double()
    def run_KF(self):
        N = 10
        u = torch.randn(N, 6).double()
        v = torch.zeros(N, 3).double()
        model = MesNet()
        measurements_covs_l = model(u.t().unsqueeze(0))
        # remember to remove this afterwards
        torch.autograd.set_detect_anomaly(True)
        for i in range(1, N):
            v[i] = self.update_pos(v[i].detach(), measurements_covs_l[i-1])

        criterion = torch.nn.MSELoss(reduction="sum")
        targ = torch.rand(10, 3).double()
        loss = criterion(v, targ)
        loss = torch.mean(loss)
        loss.backward()
        return v


    def update_pos(self, v, measurement_cov):
        Omega = torch.eye(3).double() 
        H = torch.ones((5, self.P_dim)).double()
        R = torch.diag(measurement_cov)
        Kt = H.t().mm(torch.inverse(R))
        # it is indicating inplace error even with this: 
        # Kt = H.t().mm(R)
        dx = Kt.mv(torch.ones(5).double())
        dR = self.trans(dx[:9].clone())
        v_up = dR.mv(v)
        return v_up

    def trans(self, xi):
        phi = xi[:3].clone()
        angle = torch.norm(phi.clone())

        if angle.abs().lt(1e-10):

            skew_phi = torch.eye(3).double()
            J = self.Id3 + 0.5 * skew_phi
            Rot = self.Id3 + skew_phi
        else:
            axis = phi / angle
            skew_axis = torch.eye(3).double()
            s = torch.sin(angle)
            c = torch.cos(angle)

            Rot = c * self.Id3
        return Rot
net =  UpdateModel()
net.run_KF()

【Question discussion】:

    Tags: pytorch operators backpropagation


    【Solution 1】:

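    I think the problem is that you are overwriting the elements v[i]. The assignment v[i] = ... writes into the existing tensor v, so it is an in-place operation even though no method ending in an underscore appears in the code.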
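To make the mechanism concrete, here is a reduced sketch rather than the code from the question. The names w, row, and out are made up for illustration, and _version is an internal attribute used here only for inspection. It assumes that autograd saves the vector passed to mv so it can compute the gradient with respect to the matrix, and that v[1].detach() shares storage and a version counter with v.

```python
import torch

v = torch.zeros(4, 3, dtype=torch.double)
w = torch.ones(3, 3, dtype=torch.double, requires_grad=True)  # hypothetical stand-in for dR

row = v[1].detach()        # no copy: shares storage and version counter with v
before = v._version        # internal counter, read here only for inspection

out = w.mv(row)            # autograd keeps `row` to compute the gradient w.r.t. w
v[1] = out                 # indexed assignment copies into v in place
after = v._version         # the shared counter has moved on

failed = False
message = ""
try:
    out.sum().backward()   # the saved `row` no longer matches its recorded version
except RuntimeError as err:
    failed = True
    message = str(err)

print(before, after, failed)
print(message)
```

Replacing row = v[1].detach() with row = v[1].detach().clone() should also avoid the error in this reduced case, since the saved tensor then no longer shares memory with v.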

    Instead, you can build an auxiliary list v_ inside the loop and then convert it to a tensor:

    v_ = [v[0]]
    for i in range(1, N):
        v_.append(self.update_pos(v[i].detach(), measurements_covs_l[i-1]))
    v = torch.stack(v_)
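The same pattern as a self-contained sketch that can be run on its own. It assumes a simplified stand-in for update_pos; the names update, mats, and v_init are hypothetical and not from the question.

```python
import torch

torch.manual_seed(0)
N = 5

# Hypothetical stand-ins: v_init plays the role of the initial v,
# mats plays the role of quantities derived from the network output.
v_init = torch.randn(N, 3, dtype=torch.double)
mats = torch.randn(N - 1, 3, 3, dtype=torch.double, requires_grad=True)

def update(v_row, m):
    # Simplified placeholder for update_pos: one matrix-vector product.
    return m.mv(v_row)

rows = [v_init[0]]
for i in range(1, N):
    # Append a new tensor instead of writing into v_init[i].
    rows.append(update(v_init[i].detach(), mats[i - 1]))

v = torch.stack(rows)      # shape (N, 3), built without any in-place write

loss = (v ** 2).sum()
loss.backward()            # runs without the in-place error
print(v.shape, mats.grad.shape)
```

Because nothing is written into v_init after autograd has saved its rows, the tensors recorded during the forward pass are unchanged when backward runs.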
    

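    【Discussion】:

    • How can I avoid the following in-place operators? H = torch.zeros((5, self.P_dim)); H[:3, 3:6] = Rot.clone().t()[:]; H[3:, 6:8] = torch.eye(2).clone()
    • The alternative I gave you works because it does not overwrite the objects contained in the variable v.
    • I see. Could you also address the question about re-implementing these in-place operators?
    • Because the update step is iterative, the code runs faster on the CPU. I also noticed that, for the update step at inference time, numpy runs faster than pytorch. Is there a way to optimize this?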
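
The thread leaves the first comment unanswered. One possible approach, sketched here as an assumption rather than as the answerer's method, is to assemble H from blocks with torch.cat instead of assigning into slices. The helper build_H and the placeholder Rot are hypothetical; P_dim = 18 is taken from the question. Slice assignment into a freshly created tensor is often accepted by autograd as long as that tensor is not modified again after being used, so this is only an alternative for cases where it does cause trouble.

```python
import torch

P_dim = 18  # same value as self.P_dim in the question
Rot = torch.randn(3, 3, dtype=torch.double, requires_grad=True)  # placeholder rotation

def build_H(Rot, P_dim):
    # Hypothetical helper: builds the 5 x P_dim matrix block by block.
    def zeros(r, c):
        return torch.zeros(r, c, dtype=Rot.dtype)
    # Rows 0..2: Rot^T occupies columns 3..5.
    top = torch.cat([zeros(3, 3), Rot.t(), zeros(3, P_dim - 6)], dim=1)
    # Rows 3..4: a 2 x 2 identity occupies columns 6..7.
    bottom = torch.cat(
        [zeros(2, 6), torch.eye(2, dtype=Rot.dtype), zeros(2, P_dim - 8)], dim=1
    )
    return torch.cat([top, bottom], dim=0)

H = build_H(Rot, P_dim)

# Reference built with slice assignment, used only to compare values.
H_ref = torch.zeros(5, P_dim, dtype=torch.double)
H_ref[:3, 3:6] = Rot.detach().t()
H_ref[3:, 6:8] = torch.eye(2, dtype=torch.double)

H.sum().backward()  # gradients reach Rot through the concatenation
print(H.shape, torch.equal(H.detach(), H_ref))
```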