【问题标题】:Multiclass Classification and probability prediction多类分类和概率预测
【发布时间】:2018-10-12 08:24:31
【问题描述】:
import pandas as pd
import numpy
from sklearn import cross_validation
from sklearn.naive_bayes import GaussianNB

fi = "df.csv"
# Open the file for reading and read in data
file_handler = open(fi, "r")
data = pd.read_csv(file_handler, sep=",")
file_handler.close()

# split the data into training and test data
train, test = cross_validation.train_test_split(data,test_size=0.6, random_state=0)
# initialise Gaussian Naive Bayes
naive_b = GaussianNB()


train_features = train.ix[:,0:127]
train_label = train.iloc[:,127]

test_features = test.ix[:,0:127]
test_label = test.iloc[:,127]

naive_b.fit(train_features, train_label)
test_data = pd.concat([test_features, test_label], axis=1)
test_data["p_malw"] = naive_b.predict_proba(test_features)

print "test_data\n",test_data["p_malw"]
print "Accuracy:", naive_b.score(test_features,test_label)

我编写了这段代码来接受来自 csv 文件的输入,该文件有 128 列,其中 127 列是特征,第 128 列是类标签。

我想预测样本属于每个类别的概率(有 5 个类别 (1-5))并将其打印到矩阵中并根据预测确定样本类别。 predict_proba() 没有给出所需的输出。请提出必要的更改建议。

【问题讨论】:

  • @mr_mo 你能帮忙吗

标签: python machine-learning scikit-learn naivebayes multiclass-classification


【解决方案1】:

gaussiannb.predict_proba返回模型中每个类的样本的概率。在您的情况下,它应该返回一个结果,其中五列具有与测试数据中相同的行数相同的行。您可以验证哪个列对应于使用Naive_b.classes_的课程。因此,尚不清楚为什么你说这不是所需的输出。也许,您的问题来自您正在为数据帧列分配预测proba的输出。试试:

pred_prob = naive_b.predict_proba(test_features)

而不是

test_data["p_malw"] = naive_b.predict_proba(test_features)

并使用pred_prob.shape验证其形状。第二维应为5。

如果您希望每个样本的预测标签,您可以使用预测方法,然后使用混淆矩阵来查看已正确预测的标签。

from sklearn.metrics import confusion_matrix

naive_B.fit(train_features, train_label)

pred_label = naive_B.predict(test_features)

confusion_m = confusion_matrix(test_label, pred_label)
confusion_m

这里有一些有用的阅读。

sklearn gaussiannb - http://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.GaussianNB.html#sklearn.naive_bayes.GaussianNB.predict_proba

sklearn confusion_matrix - 987654322 @

【讨论】:

    猜你喜欢
    • 2016-01-05
    • 2017-06-12
    • 2019-09-16
    • 2016-07-09
    • 2017-06-12
    • 2015-02-01
    • 2020-05-23
    • 2018-12-24
    • 2022-06-28
    相关资源
    最近更新 更多