【问题标题】:any similar function like df.mask for tensor in pytorch?pytorch中的张量是否有类似df.mask之类的功能?
【发布时间】:2020-03-21 02:11:38
【问题描述】:

我想用 -5 替换二维张量中的所有 0。

使用数据框,我可以轻松做到这一点:

df = df.mask(df=0, -5) 

但这不适用于张量。我试过了:

 y = torch.where(y = 0, -5, y) 

【问题讨论】:

    标签: python tensorflow pytorch tensor


    【解决方案1】:

    一般有两种方式。

    上面由prhmma 给出的一个是使用像y[y == 0] = -5 这样的就地突变。它既好又高效,但会破坏 autograd 操作。所以如果你想让梯度流过 y,你不应该那样做。

    另一种方法是使用torch.where,正如您所尝试的那样。正确的咒语是

    y = torch.where(y == 0, torch.tensor(-5), y)
    

    或者,如果您想与设备和数据类型无关

    five = torch.tensor(-5, dtype=y.dtype, device=y.device)
    y = torch.where(y == 0, five, y)
    

    where 不接受标量这一事实是一个令人讨厌的剪纸,但这就是 ATM 的方式。请注意,虽然选择本身是离散的且显然不可微分,但此操作将使梯度流过两个操作数。

    【讨论】:

    • 无论什么分段函数y = -5 if x == 0 else x 在 0 处不可微分,它甚至都不是连续的。所以有效的反向传播是不可能的
    • 以上评论只是为了强调您已经提到的这一重要事实,以防@LinmeiShang 计划将其用于优化目的。
    【解决方案2】:

    很简单,就用这个

    y[y==0]=-5

    【讨论】:

    • 不是特殊函数,只是python标准
    • @LinmeiShang 一般来说,许多numpy技巧(与索引有关)也适用于pytorch。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2012-08-01
    • 2023-04-07
    • 1970-01-01
    • 1970-01-01
    • 2015-11-21
    相关资源
    最近更新 更多