【问题标题】:Optional tensors in PyTorch c++ extensionPyTorch c++ 扩展中的可选张量
【发布时间】:2019-07-08 04:29:32
【问题描述】:

我正在为 pytorch 编写 C++ 扩展,并使用 c++ api 来执行此操作。对于我的 forward 函数,我需要传递一个可选的张量。在函数内部,我想根据是否传递了这个可选参数来做不同的事情。通常,我们在 C++ 中使用 NULL 作为可选指针参数,并在函数内部检查指针是否为 NULL。我不知道如何为 at::Tensor 类型的 Torch 的 c++ api 执行此操作。

void xyz_forward(
    const at::Tensor xyz1, 
    const at::Tensor xyz2, 
    const at::Tensor optional_constraints = something)
{
     if(optional_constraints){
        //do something
     }else{
        //do something else
     }
}

请注意,我不能做const at::Tensor optional_constraints = at::ones 之类的,因为该参数可以采用任何实际值并且可以具有不同的大小/形状。我不能为它分配一个数值作为可选参数。是否有对应的 NULL

【问题讨论】:

  • 也许我不明白,但你不能检查一下optional_constraints == nullptr吗?
  • @Coolness 不幸的是optional_constrains 不是指针。
  • 啊,我明白了。谢谢。

标签: c++ pytorch torch


【解决方案1】:

一种可能性是将std::optional 用作std::optional<at::Tensor> optional_constraints = std::nullopt。它可以根据上下文转换为bool,因此您可以使用if (optional_constraints) 进行检查。传一个则使用.value()方法获取张量,否则默认为std::nullopt

【讨论】:

    【解决方案2】:

    因为我找不到类似的东西,例如。 API 中的 OpenCV noArray()(主要用于传递可选矩阵,如掩码),我建议您为此目的使用重载函数

    void xyz_forward(
        const at::Tensor xyz1, 
        const at::Tensor xyz2)
    {
         // optional tensor wasnt passed
    }
    

    void xyz_forward(
        const at::Tensor xyz1, 
        const at::Tensor xyz2, 
        const at::Tensor optional_constraints)
    {
         // optional tensor passed
    }
    

    【讨论】:

      猜你喜欢
      • 2020-02-25
      • 2021-02-10
      • 1970-01-01
      • 2022-11-26
      • 1970-01-01
      • 2022-10-17
      • 1970-01-01
      • 2021-02-01
      • 1970-01-01
      相关资源
      最近更新 更多