【发布时间】:2020-04-09 22:19:48
【问题描述】:
我正在尝试使用 sklearn.pipeline.Pipeline 在 python 中定义一个管道来执行 3 个步骤:预处理、预测和后处理。最终目标是定义一个谷歌云函数,我只需传递 joblib 模型并获得该标签的预测标签和预测概率。
我成功地用前 2 个步骤定义了管道,它工作正常。但是,当我尝试包含第三个(后处理)步骤时,我会收到错误消息。我尝试了各种方法并收到不同的错误消息。
在下面的代码中,如果我从管道中删除 ('proba', FunctionTransformer(findProba()) 一切正常。我似乎无法弄清楚如何将后处理步骤包含到我的管道中。
Scikit-learn 将管道类(参见https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html)定义为:
带有最终估计器的转换管道。
依次应用变换列表和最终估计器。管道的中间步骤必须是“变换”,即它们必须实现拟合和变换方法。最终的估计器只需要实现拟合。管道中的转换器可以使用内存参数进行缓存。
阅读此定义后,我开始想知道是否可以在估算器之后包含一个步骤。但就我而言,我真的需要能够返回课程(在我的情况下为 konto)以及获得该案例的概率(proba)。如果我在第二步之后停止,我将无法计算并返回在线预测期间的概率。
我包含代码摘要以显示我在做什么:
from nltk import word_tokenize
from nltk.corpus import stopwords
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.pipeline import Pipeline
from sklearn.naive_bayes import MultinomialNB
from datetime import date
import time
import warnings
warnings.filterwarnings('ignore')
def findProba(model,Input_Text):
Input_Text = [Input_Text]
Y_predicted = model.predict(Input_Text)
Y_predict_proba = model.predict_proba(Input_Text)
max_proba_rows = np.amax(Y_predict_proba, axis=1)*100
round_off_proba = np.around(max_proba_rows, decimals = 1)
d = dict()
d['Konto'] = Y_predicted[0]
d['proba'] = round_off_proba[0]
return d
df_total = pd.read_csv('dataset_mars2019_trimmed_mapped.csv')
df=df_total.sample(frac=0.001, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(df['Input_Data'], df['LABEL'], random_state = 0, test_size=0.25)
text_clf = Pipeline([('tfidf', TfidfVectorizer()),
('clf', MultinomialNB()),
('proba', FunctionTransformer(findProba()),
])
_ = text_clf.fit(X_train, y_train)
from sklearn.externals import joblib
joblib.dump(text_clf, 'model.joblib')
【问题讨论】:
标签: scikit-learn pipeline post-processing