【问题标题】:How to add probabilities to model.predict output?如何将概率添加到 model.predict 输出?
【发布时间】:2021-05-26 10:57:21
【问题描述】:

我已经按照this tutorial 构建了一个有效的分类模型。

本教程仅输出预测的类别名称。我希望它输出类别名称及其概率,我只想输出高于某个概率的类别。例如,我只想要超过 0.5 的类别

这是用于访问模型的函数:

import pickle
import numpy as np
category_model_path="categorymodel.pkl"
category_transformer_path="categorytransformer.pkl"
sentiment_model_path="sentimentmodel.pkl"
sentiment_transformer_path="sentimenttransformer.pkl"

def get_top_k_predictions(model,X_test,k):
    
    # get probabilities instead of predicted labels, since we want to collect top 3
    np.set_printoptions(suppress=True)
    probs = model.predict_proba(X_test)

    # GET TOP K PREDICTIONS BY PROB - note these are just index
    best_n = np.argsort(probs, axis=1)[:,-k:]
    
    # GET CATEGORY OF PREDICTIONS
    preds=[[model.classes_[predicted_cat] for predicted_cat in prediction] for prediction in best_n]
    
    preds=[ item[::-1] for item in preds]
    
    return preds

category_loaded_model = pickle.load(open(category_model_path, 'rb'))
category_loaded_transformer = pickle.load(open(category_transformer_path, 'rb'))

sentiment_loaded_model = pickle.load(open(sentiment_model_path, 'rb'))
sentiment_loaded_transformer = pickle.load(open(sentiment_transformer_path, 'rb'))

那么这段代码就是用来调用函数的:

category_test_features=category_loaded_transformer.transform(["I absolutley loved the organization "])
get_top_k_predictions(category_loaded_model,category_test_features,2)

这是当前的输出:

[['Course Structure', 'Learning Materials']]

概率在函数中计算到probs 变量。我不知道如何只获得超过 0.5 的那些并将它们添加到 preds 输出。

【问题讨论】:

    标签: python machine-learning classification text-classification


    【解决方案1】:

    best_n 数组包含概率数组probs 的索引。您可以像获取标签一样使用它。你可以像这样得到标签概率元组:

    preds = [
        [(model.classes_[predicted_cat], distribution[predicted_cat])
         for predicted_cat in prediction]
        for distribution, prediction in zip(probs, best_n)]
    

    如果您不想返回概率而只想过滤它们,您可以执行以下操作:

    preds=[
        [model.classes_[predicted_cat]
         for predicted_cat in prediction if distribution[predicted_cat] > 0.5]
        for distribution, prediction in zip(probs, best_n)]
    

    【讨论】:

    • 知道如何按标签名称的字母顺序对数组进行排序吗?
    猜你喜欢
    • 2018-10-11
    • 2017-01-07
    • 2018-07-15
    • 2018-07-15
    • 2017-09-08
    • 1970-01-01
    • 2016-03-12
    • 1970-01-01
    • 2016-04-04
    相关资源
    最近更新 更多