【问题标题】:Retrieve elements from a 3D tensor with a 2D index tensor使用 2D 索引张量从 3D 张量中检索元素
【发布时间】:2021-01-05 01:42:33
【问题描述】:

我正在玩 GPT2,我有 2 个张量:

O:形状为 (B, S-1, V) 的输出张量,其中 B 是批量大小 S 是时间步数,V 是词汇量。这是生成模型的输出,并沿第二维进行了 softmax。

L:一个 2D 张量形状 (B, S-1),其中每个元素是每个样本的每个时间步长的正确标记的索引。这基本上是标签。

我想根据张量 L 从张量 O 中提取相应正确标记的预测概率,这样我最终将得到一个二维张量形状 (B, S)。除了使用循环之外,还有其他有效的方法吗?

【问题讨论】:

  • 你能添加样本张量和预期输出吗?

标签: nlp pytorch huggingface-transformers


【解决方案1】:

作为参考,我的回答基于this Medium article
本质上,您的答案在于torch.gather,假设您的两个张量都只是常规的torch.Tensors(或可以转换为一个)。

import torch

# Specify some arbitrary dimensions for now
B = 3
V = 6
S = 4

# Make example reproducible
torch.manual_seed(42)

# L necessarily has to be a torch.LongTensor, otherwise indexing will fail.
L = torch.randint(0, V, size=[B, S])

O = torch.rand([B, S, V])

# Now collect the results. L needs to have similar dimension,
# except in the axis you want to collect along.
X = torch.gather(O, dim=2, index=L.unsqueeze(dim=2))

# Make sure X has no "unnecessary" dimension
X = X.squeeze(dim=2)

很难看出这是否会产生完全正确的结果,这就是为什么我包含了一个随机种子,它使示例在结果中具有确定性,并且您可以轻松验证它是否可以获得所需的结果。但是,为了澄清起见,也可以使用低维张量,这样会更清楚 torch.gather 究竟做了什么。

请注意,torch.gather 理论上还允许您在同一行中索引多个索引。这意味着,如果您获得了一个多值示例,其中多个值是正确的,您可以类似地使用形状为 [B, S, number_of_correct_samples] 的张量 L

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2022-01-21
    • 2020-01-27
    • 2021-09-14
    • 1970-01-01
    相关资源
    最近更新 更多