【发布时间】:2020-11-07 16:28:17
【问题描述】:
背景:
随着这个question 在使用 bert 对序列进行分类时,模型使用表示分类任务的“[CLS]”标记。根据论文:
每个序列的第一个标记总是一个特殊的分类 令牌([CLS])。这个token对应的最终隐藏状态是 用作分类的聚合序列表示 任务。
查看Huggingfaces 存储库,他们的BertForSequenceClassification 使用了bert pooler 方法:
class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
我们可以看到他们采用第一个标记 (CLS) 并将其用作整个句子的表示。具体来说,他们执行hidden_states[:, 0],这看起来很像从每个状态中获取第一个元素而不是获取第一个标记隐藏状态?
我的问题:
我不明白的是他们如何将整个句子中的信息编码到这个令牌中? CLS 标记是一个常规标记,它有自己的嵌入向量,可以“学习”句子级别的表示吗?为什么我们不能只使用隐藏状态的平均值(编码器的输出)并用它来分类?
编辑:想了想:因为我们使用 CLS 令牌隐藏状态来预测,所以 CLS 令牌嵌入是否正在接受分类任务的训练,因为这是用于分类(因此是传播到其权重的误差的主要贡献者?)
【问题讨论】:
标签: python transformer huggingface-transformers bert-language-model