【问题标题】:How can I get hidden_states from BertForSequenceClassification?如何从 BertForSequenceClassification 获取 hidden_​​states?
【发布时间】:2020-01-04 16:39:46
【问题描述】:

我阅读了官方教程(https://huggingface.co/transformers/model_doc/bert.html) 并尝试设置配置,但它不起作用。

from transformers import PretrainedConfig
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
model.config.output_hidden_states = True
model.load_state_dict(torch.load('../parameter.pkl'))
model.cuda()
output = model(input)

【问题讨论】:

    标签: python pytorch bert-language-model


    【解决方案1】:

    输出应该是一个包含隐藏状态的列表。我希望因为您正在加载默认情况下可能没有输出隐藏状态的parameter.pkl,它会将您的config.output_hidden_states 覆盖为False?看看如果在加载 state_dict 后将其设置为 True 会发生什么?

    【讨论】:

    • 感谢您的建议。我按照您的建议更改了顺序,但是输出仅接收分类张量。 array([[ 5.155039 , -5.1482654], [-1.7462035, 3.0982263], [ 5.295436 , -5.6735516]], dtype=float32)
    • 如果您只是注释掉 model.load_state_dict 会发生什么?如果它在注释掉后有效,那么您就知道错误在于加载状态字典。
    猜你喜欢
    • 2020-08-03
    • 1970-01-01
    • 2023-04-04
    • 2021-06-23
    • 2018-03-31
    • 1970-01-01
    • 2020-11-07
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多