【问题标题】:How to get a Histogram of PyTorch tensors in batches?如何批量获取 PyTorch 张量的直方图?
【发布时间】:2021-10-03 23:18:35
【问题描述】:

有没有办法批量获取torch张量的直方图?

例如: x 是一个形状张量 (64, 224, 224)

# x will have shape of (64, 256)
x = batch_histogram(x, bins=256, min=0, max=255)

【问题讨论】:

    标签: python pytorch histogram tensor


    【解决方案1】:

    不确定,但在我看来这很难做到,而且 PyTorch 没有任何开箱即用的功能。

    直方图是一种统计操作。它本质上是离散的和不可微分的。此外,它们本质上不可矢量化。所以,我认为没有比基于普通循环的解决方案更简单的方法了。

    X = torch.rand(64, 224, 224)
    h = torch.cat([torch.histc(x, bins=256, min=0, max=255) for x in X], 0)
    

    如果有人有更好的解决方案,请随时发布。

    【讨论】:

    • 有没有办法至少在 numpy 或任何其他框架中做到这一点,因为这些批次将非常庞大
    • numpy 的直方图function 的工作方式与PyTorch 相同。注意它说“computed over the flattened array”,这意味着它不支持批处理。你可以像我的火炬示例一样使用 numpy 循环。
    【解决方案2】:

    可以在一行代码中使用torch.nn.functional.one_hot 做到这一点:

    torch.nn.functional.one_hot(data_tensor, num_classes).sum(dim=-2)
    

    基本原理是one_hot 确实尊重批次,并且对于给定张量的最后一维中的每个值 v,创建一个填充为 0 的张量,但第 v 个分量除外,即 1。我们它们对所有这些 one-hot 编码求和,以获得 v 在第二个最后一个维度(这是 tensor_data 中的最后一个维度)的每行数据中出现的次数。

    此方法的一个可能严重的缺点是内存使用,因为每个值都被扩展为大小为num_classes 的张量(因此,tensor_data 的大小乘以num_classes)。然而,这种内存使用是暂时的,因为sum 再次折叠了这个额外的维度,结果通常会小于tensor_data。我说“通常”是因为如果num_classes 远大于tensor_data 的最后一个维度的大小,那么结果会相应地更大。

    这是带有文档的代码,后面是 pytest 测试:

    def batch_histogram(data_tensor, num_classes=-1):
        """
        Computes histograms of integral values, even if in batches (as opposed to torch.histc and torch.histogram).
        Arguments:
            data_tensor: a D1 x ... x D_n torch.LongTensor
            num_classes (optional): the number of classes present in data.
                                    If not provided, tensor.max() + 1 is used (an error is thrown is tensor is empty).
        Returns:
            A D1 x ... x D_{n-1} x num_classes 'result' torch.LongTensor,
            containing histograms of the last dimension D_n of tensor,
            that is, result[d_1,...,d_{n-1}, c] = number of times c appears in tensor[d_1,...,d_{n-1}].
        """
        return torch.nn.functional.one_hot(data_tensor, num_classes).sum(dim=-2)
    
    def test_batch_histogram():
        data = [2, 5, 1, 1]
        expected = [0, 2, 1, 0, 0, 1]
        run_test(data, expected)
    
        data = [
            [2, 5, 1, 1],
            [3, 0, 3, 1],
        ]
        expected = [
            [0, 2, 1, 0, 0, 1],
            [1, 1, 0, 2, 0, 0],
        ]
        run_test(data, expected)
    
        data = [
            [[2, 5, 1, 1], [2, 4, 1, 1], ],
            [[3, 0, 3, 1], [2, 3, 1, 1], ],
        ]
        expected = [
            [[0, 2, 1, 0, 0, 1], [0, 2, 1, 0, 1, 0], ],
            [[1, 1, 0, 2, 0, 0], [0, 2, 1, 1, 0, 0], ],
        ]
        run_test(data, expected)
    
    
    def test_empty_data():
        data = []
        num_classes = 2
        expected = [0, 0]
        run_test(data, expected, num_classes)
    
        data = [[], []]
        num_classes = 2
        expected = [[0, 0], [0, 0]]
        run_test(data, expected, num_classes)
    
        data = [[], []]
        run_test(data, expected=None, exception=RuntimeError)  # num_classes not provided for empty data
    
    
    def run_test(data, expected, num_classes=-1, exception=None):
        data_tensor = torch.tensor(data, dtype=torch.long)
    
        if exception is None:
            expected_tensor = torch.tensor(expected, dtype=torch.long)
            actual = batch_histogram(data_tensor, num_classes)
            assert torch.equal(actual, expected_tensor)
        else:
            with pytest.raises(exception):
                batch_histogram(data_tensor, num_classes)
    

    【讨论】:

      猜你喜欢
      • 2019-01-13
      • 2020-01-03
      • 2016-07-25
      • 1970-01-01
      • 2021-09-14
      • 2020-08-31
      • 2021-11-12
      • 2021-08-22
      • 2019-04-12
      相关资源
      最近更新 更多