【问题标题】:How to train Huggingface TFT5ForConditionalGeneration model?如何训练 Huggingface TFT5ForConditionalGeneration 模型?
【发布时间】:2020-12-15 02:22:24
【问题描述】:

我的代码如下:

batch_size=8
sequence_length=25
vocab_size=100
import tensorflow as tf
from transformers import T5Config, TFT5ForConditionalGeneration
configT5 = T5Config(
    vocab_size=vocab_size,
    d_ff =512, 
)  
model = TFT5ForConditionalGeneration(configT5)

model.compile(
    optimizer = tf.keras.optimizers.Adam(),
    loss = tf.keras.losses.SparseCategoricalCrossentropy()
)
input = tf.random.uniform([batch_size,sequence_length],0,vocab_size,dtype=tf.int32)
labels = tf.random.uniform([batch_size,sequence_length],0,vocab_size,dtype=tf.int32)
input = {'inputs': input, 'decoder_input_ids': input}
model.fit(input, labels)

它会产生错误:

logits 和标签必须具有相同的第一维,得到 logits 形状 [1600,64] 和标签形状 [200] [[节点 sparse_categorical_crossentropy_3/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits (定义在 C:\Users\FA.PROJECTOR-MSK\Google Диск\Colab Notebooks\PoetryTransformer\experiments\TFT5.py:30) ]] [Op:__inference_train_function_25173] 函数调用堆栈: train_function

我不明白 - 为什么模型返回 [1600, 64] 的张量。根据https://huggingface.co/transformers/model_doc/t5.html#tft5forconditionalgeneration模型返回[batch_size, sequence_len, vocab_size]。

【问题讨论】:

    标签: tensorflow huggingface-transformers


    【解决方案1】:

    由于 TFT5ForConditionalGeneration 的 call() 方法的非标准签名,无法调用 fit()。我必须覆盖 train_step() 才能使 TFT5 正常工作。看这里 - https://colab.research.google.com/github/snapthat/TF-T5-text-to-text/blob/master/snapthatT5/notebooks/TF-T5-Datasets%20Training.ipynb#scrollTo=cgxRVn34Z0wb

    【讨论】:

      猜你喜欢
      • 2023-03-10
      • 2022-11-08
      • 1970-01-01
      • 2022-01-03
      • 1970-01-01
      • 2021-01-08
      • 1970-01-01
      • 2022-11-03
      • 2021-03-31
      相关资源
      最近更新 更多