【问题标题】:Pytorch: Custom thresholding activation function - gradientPytorch:自定义阈值激活函数 - 梯度
【发布时间】:2021-10-29 06:54:23
【问题描述】:

我创建了一个激活函数类 Threshold,它应该在 one-hot-encoded 图像张量上运行。

该函数在每个通道上执行最小-最大特征缩放,然后进行阈值处理。

class Threshold(nn.Module):
def __init__(self, threshold=.5):
    super().__init__()
    if threshold < 0.0 or threshold > 1.0:
        raise ValueError("Threshold value must be in [0,1]")
    else:
        self.threshold = threshold

def min_max_fscale(self, input):
    r"""
    applies min max feature scaling to input. Each channel is treated individually.
    input is assumed to be N x C x H x W (one-hot-encoded prediction)
    """
    for i in range(input.shape[0]):
        # N
        for j in range(input.shape[1]):
            # C
            min = torch.min(input[i][j])
            max = torch.max(input[i][j])
            input[i][j] = (input[i][j] - min) / (max - min)
    return input

def forward(self, input):
    assert (len(input.shape) == 4), f"input has wrong number of dims. Must have dim = 4 but has dim {input.shape}"

    input = self.min_max_fscale(input)
    return (input >= self.threshold) * 1.0

当我使用该函数时,我得到以下错误,因为我假设梯度不是自动计算的。

Variable._execution_engine.run_backward(RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

我已经看过How to properly update the weights in PyTorch?,但不知道如何将它应用到我的案例中。

如何计算这个函数的梯度?

感谢您的帮助。

【问题讨论】:

    标签: pytorch gradient activation


    【解决方案1】:

    问题是您正在操作和覆盖元素,autograd 无法跟踪此操作时间。相反,您应该坚持使用内置函数。您的示例并不难解决:您正在寻找沿input.shape[0] x input.shape[1] 检索最小值和最大值。然后,您将以矢量化形式一次性缩放整个张量不涉及 for 循环!

    沿多个轴计算最小值/最大值的一种方法是将它们展平:

    >>> x_f = x.flatten(2)
    

    然后,在保持所有形状的同时,找到展平轴上的最小值-最大值:

    >>> x_min = x_f.min(axis=-1, keepdim=True).values
    >>> x_max = x_f.max(axis=-1, keepdim=True).values
    

    生成的min_max_fscale 函数如下所示:

    class Threshold(nn.Module):
        def min_max_fscale(self, x):
            r"""
            Applies min max feature scaling to input. Each channel is treated individually. 
            Input is assumed to be N x C x H x W (one-hot-encoded prediction)
            """
            x_f = x.flatten(2)
            x_min, x_max = x_f.min(-1, True).values, x_f.max(-1, True).values
    
            x_f = (x_f - x_min) / (x_max - x_min)
            return x_f.reshape_as(x)
    

    重要提示:

    您会注意到您现在可以在 min_max_fscale... 上进行反向传播...但不能在 forward 上进行反向传播。这是因为您应用的布尔条件不是可微分运算。

    【讨论】:

      猜你喜欢
      • 2021-02-22
      • 1970-01-01
      • 2018-04-15
      • 2018-09-30
      • 2019-07-10
      • 2021-07-30
      • 1970-01-01
      • 2021-10-07
      • 1970-01-01
      相关资源
      最近更新 更多