【问题标题】:calculate perplexity in pytorch在pytorch中计算困惑度
【发布时间】:2020-03-31 04:59:06
【问题描述】:

我刚刚使用 pytorch 训练了一个 LSTM 语言模型。类的主体是这样的:

class LM(nn.Module):
    def __init__(self, n_vocab, 
                       seq_size, 
                       embedding_size, 
                       lstm_size, 
                       pretrained_embed):

        super(LM, self).__init__()
        self.seq_size = seq_size
        self.lstm_size = lstm_size
        self.embedding = nn.Embedding.from_pretrained(pretrained_embed, freeze = True)
        self.lstm = nn.LSTM(embedding_size,
                            lstm_size,
                            batch_first=True)
        self.fc = nn.Linear(lstm_size, n_vocab)

    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)

        return logits, state

现在我想写一个函数来计算一个句子的好坏,基于训练的语言模型(一些分数,比如困惑等)

我有点困惑,我不知道该如何计算。
类似的样本会很有用。

【问题讨论】:

    标签: python nlp pytorch language-model


    【解决方案1】:

    使用交叉熵损失时,您只需使用指数函数torch.exp() 从损失中计算困惑度。
    (pytorch cross-entropy also uses the exponential function resp. log_n)

    所以这里只是一些虚拟的例子:

    import torch
    import torch.nn.functional as F
    num_classes = 10
    batch_size  = 1
    
    # your model outputs / logits
    output      = torch.rand(batch_size, num_classes) 
    
    # your targets
    target      = torch.randint(num_classes, (batch_size,))
    
    # getting loss using cross entropy
    loss        = F.cross_entropy(output, target)
    
    # calculating perplexity
    perplexity  = torch.exp(loss)
    print('Loss:', loss, 'PP:', perplexity)  
    

    在我的例子中,输出是:

    Loss: tensor(2.7935) PP: tensor(16.3376)
    

    如果您想获得每个单词丢失所需的每个单词的困惑度,您只需要注意这一点。

    这是一个语言模型的简洁示例,它可能会从输出中计算出困惑度:

    https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/language_model/main.py#L30-L50

    【讨论】:

    • 感谢您的回答。我的问题有点不同,因为我只想给出一个句子(一个标记列表)作为输入,并得到一个分数作为输出。在这种情况下,我应该将句子和移位句子作为示例代码中的 outputtarget 给出吗?
    • @P.Alipoor 是的,当查看索引为 i 的令牌时,目标应该是令牌 i+1
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2020-01-02
    • 1970-01-01
    • 2017-06-12
    • 1970-01-01
    • 2021-04-06
    • 2022-07-17
    相关资源
    最近更新 更多