【发布时间】:2021-05-19 06:33:25
【问题描述】:
我想对形状为 (N, C, H, W) 的概率分布张量进行采样,其中维度 1(大小 C)包含具有“C”个可能性的归一化概率分布。是否有一个 pytorch 函数可以有效地并行采样张量中的所有分布?我只需要对每个分布进行一次采样,因此结果可能是具有相同形状的单热张量或形状为 (N, 1, H, W) 的索引张量。
【问题讨论】:
标签: pytorch
我想对形状为 (N, C, H, W) 的概率分布张量进行采样,其中维度 1(大小 C)包含具有“C”个可能性的归一化概率分布。是否有一个 pytorch 函数可以有效地并行采样张量中的所有分布?我只需要对每个分布进行一次采样,因此结果可能是具有相同形状的单热张量或形状为 (N, 1, H, W) 的索引张量。
【问题讨论】:
标签: pytorch
我没有看到单一的采样函数,但我能够通过计算累积概率分几个步骤对张量进行采样,独立采样每个点,然后选择在分布维度中采样 1 的第一个点:
reverse_cumulative = torch.flip(torch.cumsum(torch.flip(probabilities, [1]), dim=1), [1])
cumulative = probabilities / reverse_cumulative
sampled = (torch.rand(cumulative.shape, device=device()) <= cumulative)
idxs = sampled * one_hot
idxs[~sampled] = self.tile_count
sampled_idxs = idxs.min(dim=1).indices
【讨论】: