【发布时间】:2021-02-12 02:18:42
【问题描述】:
这个问题很像to filtering np.nan values from pytorch in a -Dimensional tensor。不同之处在于我想将相同的概念应用于 2 维或更高维的张量。
我有一个看起来像这样的张量:
import torch
tensor = torch.Tensor(
[[1, 1, 1, 1, 1],
[float('nan'), float('nan'), float('nan'), float('nan'), float('nan')],
[2, 2, 2, 2, 2]]
)
>>> tensor.shape
>>> [3, 5]
我想找到最pythonic / PyTorch 的方式来过滤(删除)张量中nan 的行。通过沿第一个(0th 轴)过滤这个tensor,我想获得一个看起来像这样的filtered_tensor:
>>> print(filtered_tensor)
>>> torch.Tensor(
[[1, 1, 1, 1, 1],
[2, 2, 2, 2, 2]]
)
>>> filtered_tensor.shape
>>> [2, 5]
【问题讨论】:
标签: python python-3.x pytorch filtering nan