【问题标题】:Unique values in PyTorch tensorPyTorch 张量中的唯一值
【发布时间】:2017-12-13 01:59:20
【问题描述】:

我想在 PyTorch 张量中找到不同的值。有没有 Tensorflow 的 unique op 的有效模拟?

【问题讨论】:

标签: python pytorch unique


【解决方案1】:

您可以转换为 numpy 数组并使用 numpy 的内置 unique 函数:

def unique(tensor1d):
    t, idx = np.unique(tensor1d.numpy(), return_inverse=True)
    return torch.from_numpy(t), torch.from_numpy(idx)  

例子:

t, idx = unique(torch.LongTensor([1, 1, 2, 4, 4, 4, 7, 8, 8]))  
# t --> [1, 2, 4, 7, 8]
# idx --> [0, 0, 1, 2, 2, 2, 3, 4, 4]

【讨论】:

  • 我认为这会起作用,但我宁愿避免 numpy 操作,因为它可能需要太多时间。无论如何,我认为这是目前唯一的解决方案,谢谢。
【解决方案2】:
  1. torch.eq()获取两个张量之间的共同项
  2. 获取索引并连接张量
  3. 终于通过torch.unique获得普通物品:
import torch as pt

a = pt.tensor([1,2,3,2,3,4,3,4,5,6])
b = pt.tensor([7,2,3,2,7,4,9,4,9,8])

equal_data = pt.eq(a, b)
pt.unique(pt.cat([a[equal_data],b[equal_data]]))

【讨论】:

  • 如果您在代码中添加一些解释会很好。
  • 我们得到两个张量之间的共同项。等价于@2 tensor.eq() 获取索引和连接张量最终得到 'torch.unique' 的常用项帮助。
  • a[equal_data].unique() 在避开猫的同时也应该做到这一点;)
【解决方案3】:

0.4.0 中有一个torch.unique() 方法

torch <= 0.3.1你可以试试:

import torch
import numpy as np

x = torch.rand((3,3)) * 10
np.unique(x.round().numpy())

【讨论】:

  • 不错..但唯一的(0.4.0)目前只有 CPU .. :-)
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2021-10-11
  • 1970-01-01
  • 2022-06-21
  • 2021-02-07
  • 1970-01-01
  • 2021-08-01
相关资源
最近更新 更多