【Posted】: 2021-10-03 23:18:35
【Question】:
Is there a way to compute histograms of a torch tensor in batches?
For example:
x is a tensor of shape (64, 224, 224)
# x will have shape (64, 256)
x = batch_histogram(x, bins=256, min=0, max=255)
【Discussion】:
Tags: python pytorch histogram tensor
Not sure, but this looks hard to do, and PyTorch offers nothing for it out of the box.
A histogram is a statistical operation. It is inherently discrete and non-differentiable, and it does not vectorize naturally. So I don't think there is anything simpler than a plain loop-based solution:
X = torch.randint(0, 256, (64, 224, 224)).float()  # histc requires a float tensor
# stack (not cat) the per-sample histograms so the result has shape (64, 256)
h = torch.stack([torch.histc(x, bins=256, min=0, max=255) for x in X])
If anyone has a better solution, feel free to post it.
【Discussion】:
numpy's histogram function works the same way as PyTorch's. Note that its documentation says "computed over the flattened array", which means it does not support batching either. You can loop with numpy just as in my torch example.
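For completeness, the loop-based numpy version mentioned above can be sketched like this (a minimal sketch; the shapes and the integer value range [0, 255] are assumptions matching the question):

```python
import numpy as np

# Loop np.histogram over the batch dimension, mirroring the torch.histc loop.
X = np.random.randint(0, 256, size=(64, 224, 224))
h = np.stack([np.histogram(x, bins=256, range=(0, 256))[0] for x in X])
# h has shape (64, 256); each row sums to 224 * 224
```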
You can do this in one line of code using torch.nn.functional.one_hot:
torch.nn.functional.one_hot(data_tensor, num_classes).sum(dim=-2)
The rationale is that one_hot does respect batches: for each value v in the last dimension of the given tensor, it creates a tensor filled with 0s except for the v-th component, which is 1. Summing all these one-hot encodings gives the number of times v appears in each row along the second-to-last dimension (which was the last dimension of data_tensor).
One potentially serious drawback of this method is memory usage, since every value is expanded into a tensor of size num_classes (so the size of data_tensor is multiplied by num_classes). However, this memory usage is temporary, because sum collapses the extra dimension again, and the result will typically be smaller than data_tensor. I say "typically" because if num_classes is much larger than the size of the last dimension of data_tensor, the result will be correspondingly larger.
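A small worked example of the one_hot-then-sum trick (the data here is taken from the tests below):

```python
import torch

data = torch.tensor([[2, 5, 1, 1],
                     [3, 0, 3, 1]])
# one_hot expands each value into a vector of length num_classes ...
expanded = torch.nn.functional.one_hot(data, num_classes=6)  # shape (2, 4, 6)
# ... and summing over that expanded dimension counts occurrences per row
hist = expanded.sum(dim=-2)  # shape (2, 6)
# hist[0] == [0, 2, 1, 0, 0, 1]; hist[1] == [1, 1, 0, 2, 0, 0]
```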
Here is the code with documentation, followed by pytest tests:
import torch


def batch_histogram(data_tensor, num_classes=-1):
    """
    Computes histograms of integral values, even if in batches (as opposed to torch.histc and torch.histogram).
    Arguments:
        data_tensor: a D1 x ... x D_n torch.LongTensor
        num_classes (optional): the number of classes present in data.
                                If not provided, tensor.max() + 1 is used (an error is thrown if tensor is empty).
    Returns:
        A D1 x ... x D_{n-1} x num_classes 'result' torch.LongTensor,
        containing histograms of the last dimension D_n of tensor,
        that is, result[d_1, ..., d_{n-1}, c] = number of times c appears in tensor[d_1, ..., d_{n-1}].
    """
    return torch.nn.functional.one_hot(data_tensor, num_classes).sum(dim=-2)
import pytest


def test_batch_histogram():
    data = [2, 5, 1, 1]
    expected = [0, 2, 1, 0, 0, 1]
    run_test(data, expected)

    data = [
        [2, 5, 1, 1],
        [3, 0, 3, 1],
    ]
    expected = [
        [0, 2, 1, 0, 0, 1],
        [1, 1, 0, 2, 0, 0],
    ]
    run_test(data, expected)

    data = [
        [[2, 5, 1, 1], [2, 4, 1, 1], ],
        [[3, 0, 3, 1], [2, 3, 1, 1], ],
    ]
    expected = [
        [[0, 2, 1, 0, 0, 1], [0, 2, 1, 0, 1, 0], ],
        [[1, 1, 0, 2, 0, 0], [0, 2, 1, 1, 0, 0], ],
    ]
    run_test(data, expected)


def test_empty_data():
    data = []
    num_classes = 2
    expected = [0, 0]
    run_test(data, expected, num_classes)

    data = [[], []]
    num_classes = 2
    expected = [[0, 0], [0, 0]]
    run_test(data, expected, num_classes)

    data = [[], []]
    run_test(data, expected=None, exception=RuntimeError)  # num_classes not provided for empty data


def run_test(data, expected, num_classes=-1, exception=None):
    data_tensor = torch.tensor(data, dtype=torch.long)
    if exception is None:
        expected_tensor = torch.tensor(expected, dtype=torch.long)
        actual = batch_histogram(data_tensor, num_classes)
        assert torch.equal(actual, expected_tensor)
    else:
        with pytest.raises(exception):
            batch_histogram(data_tensor, num_classes)
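To connect this back to the original question (a (64, 224, 224) tensor of pixel values reduced to a per-image (64, 256) histogram), the last two dimensions can be flattened first. A minimal sketch, with the shapes deliberately downscaled because one_hot temporarily multiplies memory by num_classes:

```python
import torch

def batch_histogram(data_tensor, num_classes=-1):
    return torch.nn.functional.one_hot(data_tensor, num_classes).sum(dim=-2)

x = torch.randint(0, 256, (8, 32, 32))            # a batch of 8 small "images"
h = batch_histogram(x.flatten(start_dim=1), 256)  # histogram over all pixels of each image
# h has shape (8, 256); each row sums to 32 * 32 = 1024
```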
【Discussion】: