[Question Title]: Top 3 classes in predict_proba()
[Posted]: 2020-11-18 03:35:18
[Question]:

I am working on a multi-class text classification problem that requires the top 3 predicted labels with their corresponding probabilities. I can use sklearn's predict_proba(), but I am having trouble formatting the output like Table A. My code is below:

cv = StratifiedKFold(n_splits = 10, random_state = 42, shuffle = None)

pipeline_sgd = Pipeline([
     ('vect', CountVectorizer()),
     ('tfdif', TfidfTransformer()),
     ('nb', CalibratedClassifierCV(base_estimator = SGDClassifier(), cv=cv)),
])
model = pipeline_sgd.fit(X_train, y_train)

n_top_labels = 3
probas = model.predict_proba(test["text"])
top_n_labels_idx = probas.argsort()[::-1][:n_top_labels]
top_n_probs = probas[top_n_labels_idx]
top_n_labels = label_encoder.inverse_transform(top_n_labels_idx.ravel())

results = list(zip(top_n_labels, top_n_probs))

 

Output:

[(A, .80),
 (B, .10),
 (C, .10)]

The challenge with the output above is that it does not give me the top 3 labels/probabilities for each row of text. For example, when I run inference on a set of new documents (texts), I get only a single output rather than one per document (row).
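The mismatch can be seen in a minimal NumPy sketch (toy numbers, not real model output). `probas.argsort()` does sort within each row (its default axis is the last one), but the trailing `[::-1][:3]` reverses and slices along axis 0, i.e. it reorders whole rows instead of ranking classes within each row:

```python
import numpy as np

# Toy 2-documents x 3-classes probability matrix (illustrative values only)
probas = np.array([[0.10, 0.80, 0.10],
                   [0.85, 0.05, 0.10]])

# The question's expression: [::-1][:3] reverses/slices ROWS, not per-row ranks
wrong = probas.argsort()[::-1][:3]

# Per-row descending top-n: negate, argsort along axis=1, keep first n columns
n = 3
top_idx = np.argsort(-probas, axis=1)[:, :n]
top_probs = np.take_along_axis(probas, top_idx, axis=1)

print(top_idx)    # one row of class indices per document
print(top_probs)  # matching probabilities, descending within each row
```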

The second challenge is that when I put the results into a dataframe with pd.DataFrame(data=results), I get the following:

|   | 0 | 1               |
|---|---|-----------------|
| 0 | A | [[.80,.10,.10]] |
| 1 | B | [[.85,.10,.05]] |
| 2 | C | [[.70,.20,.10]] |

It should be:

|   | 0     | 1               |
|---|-------|-----------------|
| 0 | A,B,C | [[.80,.10,.10]] |
| 1 | B,C,A | [[.85,.10,.05]] |
| 2 | C,B,A | [[.70,.20,.10]] |

Table A

| Text                                       | Predicted labels | Probabilities  |
|--------------------------------------------|------------------|----------------|
| Hello World!                               | A,B,C            | [.80,.10,.10]  |
| Have a nice Day!                           | B,C,A            | [.90,.05,.05]  |
| It's a wonderful day in the neighborhood.  | C,A,B            | [.80,.10,.10]  |
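As a sketch of how Table A could be assembled (the texts, classes, and probabilities below are toy stand-ins; in practice `probas = model.predict_proba(texts)` and `classes = model.classes_`):

```python
import numpy as np
import pandas as pd

# Hypothetical stand-ins for real model output
texts = ["Hello World!", "Have a nice Day!"]
classes = np.array(["A", "B", "C"])
probas = np.array([[0.80, 0.10, 0.10],
                   [0.05, 0.90, 0.05]])

n = 3
top_idx = np.argsort(-probas, axis=1)[:, :n]            # per-row ranking
top_probs = np.take_along_axis(probas, top_idx, axis=1)  # aligned probabilities

table = pd.DataFrame({
    "Text": texts,
    "Predicted labels": [",".join(classes[row]) for row in top_idx],
    "Probabilities": [list(np.round(row, 3)) for row in top_probs],
})
print(table)
```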

[Comments]:

    Tags: python machine-learning scikit-learn


    [Solution 1]:

    When I run your code, top_n_probs has a very strange shape, and I found it hard to recover the labels. The argsort call and the code that indexes into the sorted values look a bit off.

    Below is a quick implementation that should work.

    Take an example dataset:

    from sklearn.model_selection import StratifiedKFold
    from sklearn.pipeline import Pipeline
    from sklearn.feature_extraction.text import CountVectorizer,TfidfTransformer
    from sklearn.calibration import CalibratedClassifierCV
    from sklearn.linear_model import SGDClassifier
    
    import pandas as pd
    import numpy as np
    
    df = pd.read_csv('./smsspamcollection/SMSSpamCollection', sep='\t', names=["label", "message"])
    ham_mask = df['label'] == 'ham'
    df.loc[ham_mask, 'label'] = np.random.choice(['hamA', 'hamB'], ham_mask.sum())
    X_train = df['message']
    y_train = df['label']
    

    我的标签如下所示:

    df['label'].value_counts()
    
    hamB    2425
    hamA    2400
    spam     747
    

    And fitting with your code:

    cv = StratifiedKFold(n_splits = 10, random_state = 42, shuffle = True)
    
    pipeline_sgd = Pipeline([
         ('vect', CountVectorizer()),
         ('tfdif', TfidfTransformer()),
         ('nb', CalibratedClassifierCV(base_estimator = SGDClassifier(), cv=cv)),
    ])
    
    model = pipeline_sgd.fit(X_train, y_train)
    

    This should work:

    n_top_labels = 3
    probas = model.predict_proba(X_train[:5])
    top_n_labels_idx = np.argsort(-probas)
    top_n_probs = np.round(-np.sort(-probas), 3)
    top_n_labels = [model.classes_[i] for i in top_n_labels_idx]
    
    results = list(zip(top_n_labels, top_n_probs))
    
    pd.DataFrame(results)
    
        0   1
    0   [hamB, hamA, spam]  [0.608, 0.38, 0.012]
    1   [hamA, hamB, spam]  [0.605, 0.391, 0.004]
    2   [spam, hamB, hamA]  [0.603, 0.212, 0.185]
    3   [hamB, hamA, spam]  [0.521, 0.478, 0.001]
    4   [hamB, hamA, spam]  [0.645, 0.352, 0.003]
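A side note not from the answer: `-np.sort(-probas)` and `np.take_along_axis` on the argsorted indices produce the same descending rows; the latter makes the label/probability pairing explicit:

```python
import numpy as np

# Illustrative probability matrix
probas = np.array([[0.608, 0.380, 0.012],
                   [0.004, 0.605, 0.391]])

idx = np.argsort(-probas, axis=1)            # descending class indices per row
a = -np.sort(-probas, axis=1)                # descending probabilities per row
b = np.take_along_axis(probas, idx, axis=1)  # same values, paired with idx

print(np.array_equal(a, b))
```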
    

    [Comments]:

      [Solution 2]:

      An update to the accepted answer.

      n = 3
      
      probas = model.predict_proba(X_train)
      top_n_labels_idx = np.argsort(-probas, axis=1)[:, :n]
      top_n_probs = np.round(-np.sort(-probas), 3)[:, :n]
      top_n_labels = [model.classes_[i] for i in top_n_labels_idx]
          
      results = list(zip(top_n_labels, top_n_probs))
      
      pd.DataFrame(results)
      

      This makes sure I get the top 3 in both columns.
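If per-rank columns are preferred over list-valued cells (purely a formatting choice; the class names, values, and column names below are hypothetical), the same arrays can be spread out:

```python
import numpy as np
import pandas as pd

# Stand-ins for model.classes_ and predict_proba output (illustrative values)
classes = np.array(["hamA", "hamB", "spam"])
probas = np.array([[0.380, 0.608, 0.012],
                   [0.605, 0.391, 0.004]])

n = 3
top_idx = np.argsort(-probas, axis=1)[:, :n]
top_probs = np.round(np.take_along_axis(probas, top_idx, axis=1), 3)

# One column per rank instead of one list per cell
labels_df = pd.DataFrame(classes[top_idx], columns=[f"label_{k + 1}" for k in range(n)])
probs_df = pd.DataFrame(top_probs, columns=[f"proba_{k + 1}" for k in range(n)])
wide = pd.concat([labels_df, probs_df], axis=1)
print(wide)
```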

      [Comments]:
