【问题标题】:Pytorch Custom Loss Function with If Statement带有 If 语句的 Pytorch 自定义损失函数
【发布时间】:2021-09-26 12:41:00
【问题描述】:

我正在尝试在 Pytorch 中创建一个自定义损失函数,它使用 if 语句评估张量的每个元素并采取相应的行动。

def my_loss(outputs,targets,fin_val):
    if (outputs.detach()-fin_val.detach())*(targets.detach()-fin_val.detach())<0:
        loss=3*(outputs-targets)**2
    else:
        loss=0.3*(outputs-targets)**2
    return loss

我也试过了:

def my_loss(outputs,targets,fin_val):
    if torch.gt((outputs.detach()-fin_val.detach())*(targets.detach()-fin_val.detach()),0):
        loss=0.3*(outputs-targets)**2
    else:
        loss=3*(outputs-targets)**2
    return loss

在这两种情况下,我都会收到以下错误:

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

TIA

【问题讨论】:

    标签: python if-statement pytorch loss-function


    【解决方案1】:

    您收到此错误是因为您传递给if 语句的条件不是布尔值,而是布尔值的张量。只需检查(outputs.detach()-fin_val.detach())*(targets.detach()-fin_val.detach())&lt;0 的性质,它是张量!

    您应该做的是以矢量化形式处理此问题。您可以使用专为此用途构建的torch.where

    torch.where(condition=(outputs - fin_val)*(targets - fin_val) < 0,
                x=3*(outputs-targets)**2,
                y=0.3*(outputs-targets)**2)
    

    这将基于逐点条件张量condition 返回一个张量“xs”和“ys”。然后,您可以根据需要对其进行平均以获得实际损失值。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2019-06-06
      • 2019-05-27
      • 2021-11-13
      • 2018-09-24
      • 1970-01-01
      • 2021-07-30
      • 2020-02-25
      • 2021-02-20
      相关资源
      最近更新 更多