【Question Title】: How to prevent gradient computations for certain elements of a tensor in Pytorch
【Posted】: 2021-08-30 11:51:05
【Question Description】:

To be clear, I am not

  • asking how to prevent gradients from propagating to certain tensors (in that case, you can just set requires_grad = False for that tensor).
  • asking how to prevent gradients from propagating through an entire tensor (in that case, you can just call tensor.detach(); see this question).

I want to know how to discard the gradient computation for certain elements of a loss tensor that give NaN gradients every time; essentially, how to call .detach() on individual elements of a tensor. The way to do this in TensorFlow is to use tf.stop_gradient; see this question.
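In generic form, the closest pattern I know of in PyTorch combines torch.where with .detach() (a sketch; partial_detach is just an illustrative name, not an established API):

import torch

def partial_detach(x, mask):
    # Forward values are identical to x everywhere. In backward,
    # torch.where routes gradient only into the branch it selected,
    # so entries where mask is False receive no gradient.
    return torch.where(mask, x, x.detach())

Note that this alone does not fix the problem below: the blocked entries still receive a zero upstream gradient, and zero times the infinite local derivative of the square root at 0 is still NaN, which is why the accepted solution masks before taking the square root.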

Some context: my neural network computes a distance matrix of its predicted coordinates, as follows. The entries of the distance matrix D are given by d_ij = || coordinates_i - coordinates_j ||. I want to backpropagate through the distance-matrix construction step. However, the norm involves a square root, which is not differentiable at 0, and the diagonal of a distance matrix is 0 by construction. As a result, I get NaN gradients on the diagonal of the distance matrix, and I would like to mask them out.
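A tiny standalone reproduction of the underlying issue (my own sketch, not from the original post): the local derivative of sqrt at 0 is infinite, and the chain rule multiplies it by the zero gradient of the squared distance, yielding NaN:

import torch

x = torch.zeros(3, requires_grad=True)
d = torch.sqrt((x * x).sum())  # squared distance is exactly 0, as on the diagonal
d.backward()
print(x.grad)  # tensor([nan, nan, nan]): inf * 0 in the chain rule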

Minimal working example:

import torch

def compute_distance_matrix(coordinates):
    L = len(coordinates)
    gram_matrix = torch.mm(coordinates, torch.transpose(coordinates, 0, 1))
    gram_diag = torch.diagonal(gram_matrix, dim1=0, dim2=1)
    # gram_diag: L
    diag_1 = torch.matmul(gram_diag.unsqueeze(-1), torch.ones(1, L).to(coordinates.device))
    # diag_1: L x L
    diag_2 = torch.transpose(diag_1, dim0=0, dim1=1)
    # diag_2: L x L
    distance_matrix = torch.sqrt(diag_1 + diag_2 - (2 * gram_matrix))
    return distance_matrix

# In reality, pred_coordinates is an output of the network, but we initialize it here for a minimal working example
L = 10
pred_coordinates = torch.randn(L, 3, requires_grad=True)
true_coordinates = torch.randn(L, 3, requires_grad=False)
obj = torch.nn.MSELoss()
optimizer = torch.optim.Adam([pred_coordinates])

for i in range(500):
    pred_distance_matrix = compute_distance_matrix(pred_coordinates)
    true_distance_matrix = compute_distance_matrix(true_coordinates)
    loss = obj(pred_distance_matrix, true_distance_matrix)
    loss.backward()
    print(loss.item())
    optimizer.step()

gives:

1.2868314981460571
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
...

【Question Discussion】:

    Tags: pytorch


    【Solution 1】:

    I initialized a new matrix and used a mask to copy over only the values with differentiable gradients from the previous tensor (here, the off-diagonal entries), and only then applied the not-everywhere-differentiable operation (the square root) to the new tensor. This lets gradients flow back only through the entries selected by the mask.

    import torch
    
    def compute_distance_matrix(coordinates):
        L = len(coordinates)
        gram_matrix = torch.mm(coordinates, torch.transpose(coordinates, 0, 1))
        gram_diag = torch.diagonal(gram_matrix, dim1=0, dim2=1)
        # gram_diag: L
        diag_1 = torch.matmul(gram_diag.unsqueeze(-1), torch.ones(1, L).to(coordinates.device))
        # diag_1: L x L
        diag_2 = torch.transpose(diag_1, dim0=0, dim1=1)
        # diag_2: L x L
        squared_distance_matrix = diag_1 + diag_2 - (2 * gram_matrix)
        distance_matrix = torch.zeros_like(squared_distance_matrix)
        mask = ~torch.eye(L, dtype=torch.bool).to(coordinates.device)
        distance_matrix[mask] = torch.sqrt( squared_distance_matrix.masked_select(mask) )
        return distance_matrix
    
    # In reality, pred_coordinates is an output of the network, but we initialize it here for a minimal working example
    L = 10
    pred_coordinates = torch.randn(L, 3, requires_grad=True)
    true_coordinates = torch.randn(L, 3, requires_grad=False)
    obj = torch.nn.MSELoss()
    optimizer = torch.optim.Adam([pred_coordinates])
    
    for i in range(500):
        pred_distance_matrix = compute_distance_matrix(pred_coordinates)
        true_distance_matrix = compute_distance_matrix(true_coordinates)
        loss = obj(pred_distance_matrix, true_distance_matrix)
        loss.backward()
        print(loss.item())
        optimizer.step()
    

    gives:

    1.222102403640747
    1.2191187143325806
    1.2162436246871948
    1.2133947610855103
    1.210543155670166
    1.2076761722564697
    1.204787015914917
    1.2018715143203735
    1.198927402496338
    1.1959534883499146
    1.1929489374160767
    1.1899129152297974
    1.1868458986282349
    1.1837480068206787
    1.180619239807129
    1.1774601936340332
    1.174271583557129
    ...
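
    A quick sanity check (my own addition, reusing the names defined above): after a backward pass through the masked distance matrix, the gradient on pred_coordinates should be NaN-free.

    pred_coordinates.grad = None  # discard gradients accumulated by the loop above
    loss = obj(compute_distance_matrix(pred_coordinates),
               compute_distance_matrix(true_coordinates))
    loss.backward()
    print(torch.isnan(pred_coordinates.grad).any())  # expect: tensor(False)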
    

    【Discussion】:
