【问题标题】:How to convert a pytorch tensor of ints to a tensor of booleans?如何将 pytorch 整数张量转换为布尔张量?
【发布时间】:2019-05-02 21:32:15
【问题描述】:

我想将整数张量转换为布尔张量。

具体来说,我希望能够拥有一个将tensor([0,10,0,16]) 转换为tensor([0,1,0,1]) 的函数

这在 Tensorflow 中很简单,只需使用 tf.cast(x,tf.bool)

我希望强制转换将所有大于 0 的整数更改为 1,并将所有等于 0 的整数更改为 0。这相当于大多数语言中的 !!

由于 pytorch 似乎没有专用的布尔类型可以转换,所以这里最好的方法是什么?

编辑:我正在寻找一种矢量化解决方案,而不是循环遍历每个元素。

【问题讨论】:

  • 在每个元素上调用 bool(int)。或者在 numpy 中:使用 array.astype(...)
  • 这是需要for循环的简单解决方案,是的。但是有矢量化的解决方案吗?
  • astype 版本几乎肯定是矢量化的。
  • @ThomasLang 在 pytorch 中没有 .astype,所以必须要 convert to numpy-> cast -> load to pytorch 哪个 IMO 效率低下

标签: python casting boolean pytorch tensor


【解决方案1】:

您可以使用如下所示的比较:

 >>> a = tensor([0,10,0,16])
 >>> result = (a == 0)
 >>> result
 tensor([ True, False,  True, False])

【讨论】:

    【解决方案2】:

    您正在寻找的是为给定的整数张量生成一个布尔掩码。为此,您可以使用简单的比较运算符 (>) 或使用 torch.gt() 来简单地检查条件:“张量中的值是否大于 0”,这将为我们提供所需的结果。

    # input tensor
    In [76]: t   
    Out[76]: tensor([ 0, 10,  0, 16])
    
    # generate the needed boolean mask
    In [78]: t > 0      
    Out[78]: tensor([0, 1, 0, 1], dtype=torch.uint8)
    

    # sanity check
    In [93]: mask = t > 0      
    
    In [94]: mask.type()      
    Out[94]: 'torch.ByteTensor'
    

    注意:在 PyTorch 1.4+ 版本中,上述操作会返回'torch.BoolTensor'

    In [9]: t > 0  
    Out[9]: tensor([False,  True, False,  True])
    
    # alternatively, use `torch.gt()` API
    In [11]: torch.gt(t, 0)
    Out[11]: tensor([False,  True, False,  True])
    

    如果您确实想要单个位(0s 或 1s),请使用:

    In [14]: (t > 0).type(torch.uint8)   
    Out[14]: tensor([0, 1, 0, 1], dtype=torch.uint8)
    
    # alternatively, use `torch.gt()` API
    In [15]: torch.gt(t, 0).int()
    Out[15]: tensor([0, 1, 0, 1], dtype=torch.int32)
    

    此功能请求问题中讨论了此更改的原因:issues/4764 - Introduce torch.BoolTensor ...


    TL;DR:简单的一个班轮

    t.bool().int()
    

    【讨论】:

    • 这会在 PyTorch 1.4.0 中返回“torch.BoolTensor”。
    【解决方案3】:

    将布尔值转换为数值:

    a = torch.tensor([0,4,0,0,5,0.12,0.34,0,0])
    print(a.gt(0)) # output in boolean dtype
    # output: tensor([False,  True, False, False,  True,  True,  True, False, False])
    
    print(a.gt(0).to(torch.float32)) # output in float32 dtype
    # output: tensor([0., 1., 0., 0., 1., 1., 1., 0., 0.])
    

    【讨论】:

      【解决方案4】:

      另一种选择是简单地做:

      temp = torch.tensor([0,10,0,16])
      temp.bool()
      #Returns
      tensor([False,  True, False,  True])
      

      【讨论】:

        【解决方案5】:

        PyTorch 的to(dtype) 方法有方便的data-type named aliases。您可以直接拨打bool:

        >>> t.bool()
        tensor([False,  True, False,  True])
        
        >>> t.bool().int()
        tensor([0, 1, 0, 1], dtype=torch.int32)
        

        【讨论】:

          猜你喜欢
          • 2022-10-17
          • 2020-08-05
          • 2019-07-29
          • 2021-10-11
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          • 2019-06-13
          相关资源
          最近更新 更多