【发布时间】:2021-01-27 17:12:32
【问题描述】:
有两个张量 :inputs_tokens 是一批 20x300 的令牌 ID seq_A 是我的模型输出,大小为 [20, 300, 512](批次中每个标记的 512 个向量)
seq_A.size()
Out[1]: torch.Size([20, 300, 512])
inputs_tokens.size()
torch.Size([20, 300])
我只想获取令牌 101 (CLS) 的向量,如下所示:
cls_tokens = (inputs_tokens == 101)
cls_tokens
Out[4]:
tensor([[ True, False, False, ..., False, False, False],
[ True, False, False, ..., False, False, False],
[ True, False, False, ..., False, False, False], ...
如何对 seq_A 进行切片以仅获取每个批次的 cls_tokens 中为真的向量? 当我这样做时
seq_A[cls_tokens].size()
Out[7]: torch.Size([278, 512])
但我仍然需要它以达到 [20 x N x 512] 的大小(否则我不知道它属于哪个样本)
【问题讨论】:
-
每批是否有
N101 个令牌?
标签: pytorch