【问题标题】:how to predict a character based on character based RNN model?如何基于基于字符的 RNN 模型预测字符?
【发布时间】:2021-02-05 13:00:08
【问题描述】:

我想创建一个预测函数来完成“句子”的一部分 这里使用的模型是基于字符的 RNN(LSTM)。我们应该采取哪些步骤? 我试过了,但我不能输入句子

 def generate(self) -> Tuple[List[Token], torch.tensor]:

    start_symbol_idx = self.vocab.get_token_index(START_SYMBOL, 'tokens')
   # print(start_symbol_idx)
    end_symbol_idx = self.vocab.get_token_index(END_SYMBOL, 'tokens')
    padding_symbol_idx = self.vocab.get_token_index(DEFAULT_PADDING_TOKEN, 'tokens')

    log_likelihood = 0.
    words = []
    state = (torch.zeros(1, 1, self.hidden_size), torch.zeros(1, 1, self.hidden_size))

    word_idx = start_symbol_idx

    for i in range(self.max_len):
        tokens = torch.tensor([[word_idx]])

        embeddings = self.embedder({'tokens': tokens})
        output, state = self.rnn._module(embeddings, state)
        output = self.hidden2out(output)

        log_prob = torch.log_softmax(output[0, 0], dim=0)

        dist = torch.exp(log_prob)

        word_idx = start_symbol_idx

        while word_idx in {start_symbol_idx, padding_symbol_idx}:
            word_idx = torch.multinomial(
                dist, num_samples=1, replacement=False).item()

        log_likelihood += log_prob[word_idx]

        if word_idx == end_symbol_idx:
            break

        token = Token(text=self.vocab.get_token_from_index(word_idx, 'tokens'))
        words.append(token)

    return words, log_likelihood,start_symbol_idx

【问题讨论】:

  • 你试过什么?您是否在网上找到任何解决此问题的资源?
  • 这段代码对我有帮助吗?
  • 你的目标是什么?您是否有想要从中生成的经过训练的模型?你想训练一个模型来生成字符吗?您在寻找教程吗?你想知道如何训练/使用机器学习模型吗?
  • 我已经有一个generate函数,代码如上,它不是根据用户给定的输入来完成一个句子的。

标签: nlp lstm recurrent-neural-network


【解决方案1】:

这里有两个关于如何使用机器学习库生成文本TensorflowPyTorch的教程。

【讨论】:

    【解决方案2】:

    这段代码sn-p是allennlp“语言模型”教程的一部分,这里定义了generate函数来计算token的概率,并根据模型输出的最大似然度找到最好的token和token序列,完整代码在下面的 colab notebook 中,您可以参考:https://colab.research.google.com/github/mhagiwara/realworldnlp/blob/master/examples/generation/lm.ipynb#scrollTo=8AU8pwOWgKxE 在训练了使用此功能的语言模型之后,您可以说:

    for _ in range(50):
      tokens, _ = model.generate()
      print(''.join(token.text for token in tokens))
    

    【讨论】:

      猜你喜欢
      • 2021-07-18
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2020-08-21
      • 1970-01-01
      • 2019-09-19
      • 1970-01-01
      相关资源
      最近更新 更多