【问题标题】:pytorch autograd : getting pixel grid tensor from coordinates tensor in a differentiable waypytorch autograd:以可微的方式从坐标张量获取像素网格张量
【发布时间】:2021-08-19 10:09:45
【问题描述】:

我的模型输出包含画布内矩形的坐标,我试图从坐标表示中获取此输出的像素表示,然后将损失应用于像素表示:

    # get prediction
    ypred=forward(x,w)

    # rasterize pred + test
    ytrain=rasterize(ytrain,300,600)
    ypred=rasterize(ypred,300,600)
    
    # update loss
    loss = get_loss(ytrain, ypred)

    # get gradient
    loss.backward()
    
    # update weights
    with torch.no_grad():
        w -= lr * w.grad

我已经构建了这两个替代的光栅化玩具函数:

# rasterize tensor : for loop
def rasterize_toy(tn,w,h):
    nsamples=tn.size()[0]
    #nsamples=4
    vtn=torch.zeros(nsamples,h,w,3,dtype=torch.float, requires_grad=True)
    #vtn=torch.empty(nsamples,h,w,3,dtype=torch.float, requires_grad=True)
    for i in range(nsamples): # for each sample
        top=(tn[i]*5).long()
        vtn[i]=add_tn_bg(vtn[i],h,w)
        #vtn[i]=get_tn_bg(h,w)
        vtn[i,top:,:,0]=255/255
        vtn[i,top:,:,1]=255/255
        vtn[i,top:,:,2]=255/255
    #vtn=vtn.float()
    return vtn

# rasterize tensor : index_put
def rasterize_toy2(tn,w,h):
    nsamples=tn.size()[0]
    top=(tn[0]*10).long()
    print("top",top)
    v=100#1.0 #255
    #vtn=torch.zeros(nsamples,h,w,3,dtype=torch.float, requires_grad=True)
    vtn=torch.zeros(nsamples,h,w,dtype=torch.float, requires_grad=True)
    indices=[(torch.ones(w)*top).long(),
             torch.arange(0,w).long()]
    values=torch.ones(w)*v    
    vtn[0]=vtn[0].index_put(indices, values)
    return vtn

但在光栅化步骤之后调用loss.backward() 时,它们都会产生此错误:

RuntimeError: 叶变量已被移动到图内部

我的问题似乎与这个尚未解决的问题非常相似:

来源 1

PyTorch: Differentiable operations to go from coordinate tensor to grid tensor

我还检查了堆栈溢出之外的以下来源:

来源 2

link : GitHub - ksheng-/fast-differentiable-rasterizer: PyTorch 的可微分贝塞尔曲线光栅化器

问题:虽然这个 git 提出了光栅化数据结构的方法,但在我看来,它不允许使用可微变量作为最终光栅化图像的索引。

来源 3

链接:https://discuss.pytorch.org/t/leaf-variable-moved-into-graph-interior/17489

问题:

  • masked_scatter、gather 和 grid_sample 函数似乎与我想要做的不匹配
  • index_put 似乎符合我的需求,但我的第二个光栅化函数基于它,它生成的错误与基于 for 循环的光栅化函数相同

提前感谢您的帮助

【问题讨论】:

    标签: python machine-learning deep-learning pytorch autograd


    【解决方案1】:

    在将模型输出输入loss.backward() 操作之前,我设法将其栅格化,如下所示:

    解决方案有点讨厌,需要以像素表示的形状初始化坐标表示 x,x 值仅占该形状的一小部分。然后光栅化函数在完整的像素表示中写入新值,从它包含的几个初始化值,再加上更新的权重。

    这是训练协议:

        # Training
        lr = 0.01
        for iepoch in range(nepochs):
    
            # get prediction
            ypred=mforward(x,w)
    
            # rasterize pred + test
            y=mrast(y,300,600)
            ypred=mrast(ypred,300,600)
            
            # update loss
            loss = get_loss(y, ypred)
    
            # get gradient
            loss.backward()
            
            # update weights
            with torch.no_grad():
                w -= lr * w.grad
    
            # reset gradient #37
            w.grad.zero_()
    

    这里是光栅化函数:

    def rasterize_toy3(tn,w,h):
        htn=torch.ones(nsamples,h,w,nc,dtype=torch.float)*0
        for i in range(nsamples): # for each sample
            top=(tn[i,0,0,0]*5).long()
            tn[i,top:,:,0]=100/255
            tn[i,top:,:,1]=100/255
            tn[i,top:,:,2]=100/255
        return tn
    

    这里是数据初始化函数:

    def set_data():
        
        ## data : origin
        if 1==0:
            x = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
            y = torch.tensor([2, 4, 6, 8], dtype=torch.float32)
        
        ## data : alt
        if 1==1:
            
            nsamples,h,w,nc=4,60,30,3
            x=torch.ones(nsamples,h,w,nc,dtype=torch.float)*0
            x[:,0,0,0]=torch.tensor([1, 2, 3, 4], dtype=torch.float32)
            y=torch.ones(nsamples,h,w,nc,dtype=torch.float)*0
            y[:,0,0,0]=torch.tensor([2, 4, 6, 8], dtype=torch.float32)
    
        return x,y
    

    如果有人有更优雅的解决方案,我很感兴趣。

    【讨论】:

      猜你喜欢
      • 2020-10-21
      • 2021-09-20
      • 2021-01-19
      • 2021-10-11
      • 1970-01-01
      • 2020-06-28
      • 1970-01-01
      • 1970-01-01
      • 2020-08-31
      相关资源
      最近更新 更多