【问题标题】:Filter out NaN values from a PyTorch N-Dimensional tensor从 PyTorch N 维张量中过滤掉 NaN 值
【发布时间】:2021-02-12 02:18:42
【问题描述】:

这个问题很像to filtering np.nan values from pytorch in a -Dimensional tensor。不同之处在于我想将相同的概念应用于 2 维或更高维的张量。

我有一个看起来像这样的张量:

import torch

tensor = torch.Tensor(
[[1, 1, 1, 1, 1],
 [float('nan'), float('nan'), float('nan'), float('nan'), float('nan')],
 [2, 2, 2, 2, 2]]
)
>>> tensor.shape
>>> [3, 5]

我想找到最pythonic / PyTorch 的方式来过滤(删除)张量中nan 的行。通过沿第一个(0th 轴)过滤这个tensor,我想获得一个看起来像这样的filtered_tensor

>>> print(filtered_tensor)
>>> torch.Tensor(
[[1, 1, 1, 1, 1],
 [2, 2, 2, 2, 2]]
)
>>> filtered_tensor.shape
>>> [2, 5]

【问题讨论】:

    标签: python python-3.x pytorch filtering nan


    【解决方案1】:

    使用 PyTorch 的 isnan()any() 使用获得的布尔掩码对 tensor 的行进行切片,如下所示:

    filtered_tensor = tensor[~torch.any(tensor.isnan(),dim=1)]
    

    请注意,这将删除其中包含 nan 值的任何行。如果您只想删除所有值为nan 的行,请将torch.any 替换为torch.all

    对于 N 维张量,您可以将除第一个 dim 之外的所有 dims 展平并应用与上述相同的过程:

    #Flatten:
    shape = tensor.shape
    tensor_reshaped = tensor.reshape(shape[0],-1)
    #Drop all rows containing any nan:
    tensor_reshaped = tensor_reshaped[~torch.any(tensor_reshaped.isnan(),dim=1)]
    #Reshape back:
    tensor = tensor_reshaped.reshape(tensor_reshaped.shape[0],*shape[1:])
    

    【讨论】:

    • 美丽。这正是我一直在寻找的。我应该检查一下torch.any() 有一个dim 参数。展示如何在展平时做同样的事情的额外积分!我会接受这个作为答案,但你有一个小错误。 Tensors 没有 t.isnan() 函数,它只是顶级 torch.isnan(t) 函数。如果您不介意,请更改它,我会接受您的回答:D。
    • Tensor 确实有一个 isnan 方法,请查看 pytorch.org/docs/stable/tensors.html 。这就是代码完美运行的原因。你是对的,它确实在后台调用torch.isnan
    • 你是对的!我使用的是旧版本的 PyTorch。再次感谢!我已经接受了答案
    • N维码对我不起作用;从tensor = torch.arange(600, dtype=torch.float32).reshape(1, 3, 20, 10) 开始,然后是tensor[0, 2, 0, 0] = float('nan'),结果的形状为 (0, 3, 20, 10)
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2017-06-30
    • 1970-01-01
    • 2021-03-02
    • 2019-12-25
    • 2020-08-31
    • 2018-09-02
    • 2020-09-26
    相关资源
    最近更新 更多