【问题标题】:Calculating F1 score, precision, recall in tfhub retraining script在 tfhub 再训练脚本中计算 F1 分数、精度、召回率
【发布时间】:2018-11-14 23:36:45
【问题描述】:

我正在使用 tensorflow hub 进行图像再训练分类任务。 tensorflow 脚本retrain.py 默认计算cross_entropy 和accuracy。

train_accuracy, cross_entropy_value = sess.run([evaluation_step, cross_entropy],feed_dict={bottleneck_input: train_bottlenecks, ground_truth_input: train_ground_truth})

我想获得 F1 分数、准确率、召回率和混淆矩阵。如何使用此脚本获取这些值?

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    下面我介绍了一种使用 scikit-learn 包计算所需指标的方法。

    您可以使用precision_recall_fscore_support 方法计算F1 分数、精度和召回率,使用confusion_matrix 方法计算混淆矩阵:

    from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
    

    这两种方法都采用两个类似一维数组的对象,它们分别存储地面实况和预测标签。

    在提供的代码中,训练数据的真实标签存储在10541060 行中定义的train_ground_truth 变量中,而validation_ground_truth 存储验证数据的真实标签并定义在线1087

    计算预测类标签的张量由add_evaluation_step 函数定义并返回。您可以修改1034 行以捕获该张量对象:

    evaluation_step, prediction = add_evaluation_step(final_tensor, ground_truth_input)
    # now prediction stores the tensor object that 
    # calculates predicted class labels
    

    现在您可以更新行 1076 以便在调用 sess.run() 时评估 prediction

    train_accuracy, cross_entropy_value, train_predictions = sess.run(
        [evaluation_step, cross_entropy, prediction],
        feed_dict={bottleneck_input: train_bottlenecks,
                   ground_truth_input: train_ground_truth})
    
    # train_predictions now stores class labels predicted by model
    
    # calculate precision, recall and F1 score
    (train_precision,
     train_recall,
     train_f1_score, _) = precision_recall_fscore_support(y_true=train_ground_truth,
                                                          y_pred=train_predictions,
                                                          average='micro')
    # calculate confusion matrix
    train_confusion_matrix = confusion_matrix(y_true=train_ground_truth,
                                              y_pred=train_predictions)
    

    同样,您可以通过修改1095 行来计算验证子集的指标:

    validation_summary, validation_accuracy, validation_predictions = sess.run(
        [merged, evaluation_step, prediction],
        feed_dict={bottleneck_input: validation_bottlenecks,
                   ground_truth_input: validation_ground_truth})
    
    # validation_predictions now stores class labels predicted by model
    
    # calculate precision, recall and F1 score
    (validation_precision,
     validation_recall,
     validation_f1_score, _) = precision_recall_fscore_support(y_true=validation_ground_truth,
                                                               y_pred=validation_predictions,
                                                               average='micro')
    # calculate confusion matrix
    validation_confusion_matrix = confusion_matrix(y_true=validation_ground_truth,
                                                   y_pred=validation_predictions)
    

    最后,代码调用run_final_eval 在测试数据上评估训练模型。在这个函数中,predictiontest_ground_truth 已经定义好了,所以你只需要包含代码来计算所需的指标:

    test_accuracy, predictions = eval_session.run(
        [evaluation_step, prediction],
        feed_dict={
            bottleneck_input: test_bottlenecks,
            ground_truth_input: test_ground_truth
        })
    
    # calculate precision, recall and F1 score
    (test_precision,
     test_recall,
     test_f1_score, _) = precision_recall_fscore_support(y_true=test_ground_truth,
                                                         y_pred=predictions,
                                                         average='micro')
    # calculate confusion matrix
    test_confusion_matrix = confusion_matrix(y_true=test_ground_truth,
                                             y_pred=predictions)
    

    请注意,提供的代码通过设置average='micro' 计算全局 F1 分数。 scikit-learn 包支持的不同平均方法在User Guide 中进行了描述。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2023-01-05
      • 2017-05-28
      • 2017-06-21
      • 2012-01-19
      • 2019-11-16
      • 2020-04-16
      • 2012-11-26
      • 2023-03-29
      相关资源
      最近更新 更多