【Question Title】: Tensorflow: Compute Precision, Recall, F1 Score
【Posted】: 2022-01-05 08:20:05
【Question Description】:

I built a BERT model from Huggingface (bert-base-multilingual-cased) and would like to evaluate the model's precision, recall, and F1 score in addition to accuracy, since accuracy is not always the best metric for evaluation.

Here is the example notebook that I adapted for my use case.

Creating the train/test data:

from transformers import BertTokenizer, TFBertModel, TFBertForSequenceClassification

TEST_SPLIT = 0.1
BATCH_SIZE = 2

# x is the list of input texts; tfdataset is the tf.data.Dataset built from them
train_size = int(len(x) * (1-TEST_SPLIT))

tfdataset = tfdataset.shuffle(len(x))
tfdataset_train = tfdataset.take(train_size)
tfdataset_test = tfdataset.skip(train_size)

tfdataset_train = tfdataset_train.batch(BATCH_SIZE)
tfdataset_test = tfdataset_test.batch(BATCH_SIZE)
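For context, `x` and `tfdataset` are defined earlier in the notebook; a minimal, self-contained sketch of the same take/skip split (with dummy data standing in for the tokenized inputs, and variable names chosen to mirror the snippet above) looks like this:

```python
import tensorflow as tf

# Hypothetical stand-ins for the notebook's variables: x holds the
# (tokenized) inputs and y the integer labels.
x = list(range(10))
y = [i % 2 for i in range(10)]

tfdataset = tf.data.Dataset.from_tensor_slices((x, y))

TEST_SPLIT = 0.1
train_size = int(len(x) * (1 - TEST_SPLIT))  # 9

# Note: by default shuffle() reshuffles on every pass, so a take()/skip()
# split can overlap between epochs; pinning the order avoids that.
tfdataset = tfdataset.shuffle(len(x), reshuffle_each_iteration=False)
tfdataset_train = tfdataset.take(train_size)
tfdataset_test = tfdataset.skip(train_size)

print(sum(1 for _ in tfdataset_train))  # 9
print(sum(1 for _ in tfdataset_test))   # 1
```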

Building the model:

from tensorflow.keras import optimizers, losses

MODEL_NAME = 'bert-base-multilingual-cased'
N_EPOCHS = 2

model = TFBertForSequenceClassification.from_pretrained(MODEL_NAME)
optimizer = optimizers.Adam(learning_rate=3e-5)
loss = losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

model.fit(tfdataset_train, batch_size=BATCH_SIZE, epochs=N_EPOCHS)

Sample output:

All model checkpoint layers were used when initializing TFBertForSequenceClassification.

Some layers of TFBertForSequenceClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1/2
415/415 [==============================] - 741s 2s/step - loss: 0.6652 - accuracy: 0.6321
Epoch 2/2
415/415 [==============================] - 717s 2s/step - loss: 0.6619 - accuracy: 0.6429
<keras.callbacks.History at 0x7fc970d72750>

Evaluation:

benchmarks = model.evaluate(tfdataset_test, return_dict=True, batch_size=BATCH_SIZE)
print(benchmarks)

Sample output:

93/93 [==============================] - 42s 404ms/step - loss: 0.6536 - accuracy: 0.6108
{'loss': 0.6535539627075195, 'accuracy': 0.6108108162879944}

With this I get the accuracy score. However, I would like a classification report containing all of the metrics mentioned above.

Does anyone know how to achieve that with these kinds of `tfdatasets`?

Thanks in advance!

【Question Comments】:

  • Since you are using the Keras API, you can add them in the metrics section of your code; have a look here: keras.io/api/metrics

Tags: python tensorflow machine-learning huggingface-transformers bert-language-model


【Solution 1】:

The easiest way is to use tensorflow-addons in addition to the metrics that belong to the tf main/base package.

    #pip install tensorflow-addons
    import tensorflow as tf
    import tensorflow_addons as tfa

    ....

    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.00001),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=[tf.keras.metrics.Accuracy(),
                           tf.keras.metrics.Precision(),
                           tf.keras.metrics.Recall(),
                           tfa.metrics.F1Score(num_classes=nb_classes,
                                               average='macro',
                                               threshold=0.5)])

【Comments】:

  • When I change it to this, I get the following error message: "ValueError: Shapes (None, 1) and (None, 2) are incompatible". Do you happen to know the reason for this?
  • Because you are probably doing binary classification rather than categorical classification.
  • Now it prints the following error message: "ValueError: `logits` and `labels` must have the same shape, received ((None, 2) vs (None, 1))". So I guess I have to figure out how to reshape my `tfdataset`
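The shape errors in the comments above come from sparse integer labels of shape `(None, 1)` being compared against two logits of shape `(None, 2)`. One way to resolve this is to one-hot encode the labels so they match the logits (a sketch using NumPy so the shape change is explicit; in the actual pipeline this would happen on the dataset, e.g. via `tf.one_hot`):

```python
import numpy as np

labels = np.array([[1], [0], [1]])   # shape (3, 1): sparse integer labels
num_classes = 2

# Index into an identity matrix to one-hot encode: (3, 1) -> (3, 2)
one_hot = np.eye(num_classes)[labels.ravel()]
print(one_hot.shape)   # (3, 2) -- now compatible with (None, 2) logits
```

With one-hot targets, the loss would also need to change from `SparseCategoricalCrossentropy` to `CategoricalCrossentropy`.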
【Solution 2】:

This worked for me (found here):

from keras import backend as K

def recall_m(y_true, y_pred):
    # TP / (TP + FN), with epsilon to avoid division by zero
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    recall = true_positives / (possible_positives + K.epsilon())
    return recall

def precision_m(y_true, y_pred):
    # TP / (TP + FP)
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    return precision

def f1_m(y_true, y_pred):
    # Harmonic mean of precision and recall
    precision = precision_m(y_true, y_pred)
    recall = recall_m(y_true, y_pred)
    return 2*((precision*recall)/(precision+recall+K.epsilon()))

# compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc', f1_m, precision_m, recall_m])
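As a sanity check, the same formulas can be reproduced with plain NumPy on a small batch (a standalone sketch mirroring `precision_m`, `recall_m`, and `f1_m` above; the example values are made up):

```python
import numpy as np

def recall_np(y_true, y_pred, eps=1e-7):
    # Threshold predictions at 0.5 via round, then TP / (TP + FN)
    y_pred = np.round(np.clip(y_pred, 0, 1))
    tp = np.sum(y_true * y_pred)
    return tp / (np.sum(y_true) + eps)

def precision_np(y_true, y_pred, eps=1e-7):
    # TP / (TP + FP)
    y_pred = np.round(np.clip(y_pred, 0, 1))
    tp = np.sum(y_true * y_pred)
    return tp / (np.sum(y_pred) + eps)

def f1_np(y_true, y_pred, eps=1e-7):
    p = precision_np(y_true, y_pred)
    r = recall_np(y_true, y_pred)
    return 2 * p * r / (p + r + eps)

y_true = np.array([1, 1, 0, 0, 1])
y_pred = np.array([0.9, 0.2, 0.8, 0.1, 0.7])
# After rounding: TP = 2 (indices 0, 4), FP = 1 (index 2), FN = 1 (index 1)
print(precision_np(y_true, y_pred))  # ≈ 0.6667
print(recall_np(y_true, y_pred))     # ≈ 0.6667
print(f1_np(y_true, y_pred))         # ≈ 0.6667
```

Note that, like the Keras functions above, these compute batch-wise scores; Keras then averages them over batches, which is only an approximation of the epoch-level precision/recall/F1.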

【Comments】:
