【发布时间】:2021-03-26 00:05:28
【问题描述】:
我试图通过将函数 OneVsRestClassifier 与我自己的实现进行比较来验证我是否正确理解了 SVM - OVA(One-versus-All)的工作原理。
在下面的代码中,我在训练阶段实现了num_classes分类器,然后在测试集上对它们进行了测试,并选择了返回最高概率值的那个。
import pandas as pd
import numpy as np
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score,classification_report
from sklearn.preprocessing import scale
# Read dataset
df = pd.read_csv('In/winequality-white.csv', delimiter=';')
X = df.loc[:, df.columns != 'quality']
Y = df.loc[:, df.columns == 'quality']
my_classes = np.unique(Y)
num_classes = len(my_classes)
# Train-test split
np.random.seed(42)
msk = np.random.rand(len(df)) <= 0.8
train = df[msk]
test = df[~msk]
# From dataset to features and labels
X_train = train.loc[:, train.columns != 'quality']
Y_train = train.loc[:, train.columns == 'quality']
X_test = test.loc[:, test.columns != 'quality']
Y_test = test.loc[:, test.columns == 'quality']
# Models
clf = [None] * num_classes
for k in np.arange(0,num_classes):
my_model = SVC(gamma='auto', C=1000, kernel='rbf', class_weight='balanced', probability=True)
clf[k] = my_model.fit(X_train, Y_train==my_classes[k])
# Prediction
prob_table = np.zeros((len(Y_test), num_classes))
for k in np.arange(0,num_classes):
p = clf[k].predict_proba(X_test)
prob_table[:,k] = p[:,list(clf[k].classes_).index(True)]
Y_pred = prob_table.argmax(axis=1)
print("Test accuracy = ", accuracy_score( Y_test, Y_pred) * 100,"\n\n")
测试精度等于0.21,而使用函数OneVsRestClassifier时,返回0.59。为了完整起见,我还报告了其他代码(预处理步骤与之前相同):
....
clf = OneVsRestClassifier(SVC(gamma='auto', C=1000, kernel='rbf', class_weight='balanced'))
clf.fit(X_train, Y_train)
Y_pred = clf.predict(X_test)
print("Test accuracy = ", accuracy_score( Y_test, Y_pred) * 100,"\n\n")
我自己的 SVM - OVA 实现有什么问题吗?
【问题讨论】:
-
我猜你不应该为自己使用
predict_proba方法和在内置版本中使用predict方法。我还猜想accuracy_score函数适用于预测而不是预测概率... -
@AlexanderRiedel
Accuracy_score确实适用于预测。我不认为predict_proba在使用predict方法时可以改变SVC 的预测。我认为predict_proba只是将概率值与决策函数相关联...... -
是的,但是
predict_proba不返回预测,而是返回概率矩阵...我不明白你为什么不只预测类别而是预测概率
标签: python scikit-learn svm multiclass-classification