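【Posted】: 2018-11-18 00:53:48
【Question】:
I am trying to learn how to find the best parameters for a classifier, so I am using GridSearchCV on a multi-class classification problem. I took the dummy code posted in Does not GridSearchCV support multi-class? and simply ran it with n_classes=3.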
import numpy as np
from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler,label_binarize
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import GridSearchCV  # sklearn.grid_search is deprecated in favour of sklearn.model_selection
from sklearn.metrics import accuracy_score, recall_score, f1_score, roc_auc_score, make_scorer
X, y = make_classification(n_samples=3000, n_features=10, weights=[0.1, 0.9, 0.3],n_classes=3, n_clusters_per_class=1,n_informative=2)
pipe = make_pipeline(StandardScaler(), SVC(kernel='rbf', class_weight='balanced'))  # class_weight='auto' is deprecated in favour of 'balanced'
param_space = dict(svc__C=np.logspace(-5,0,5), svc__gamma=np.logspace(-2, 2, 10))
# note: f1_score uses average='binary' unless another average is passed
my_scorer = make_scorer(f1_score, greater_is_better=True)
gscv = GridSearchCV(pipe, param_space, scoring=my_scorer)
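GridSearchCV accepts a 1-D multi-class target directly; in the setup above, the piece that cannot cope with three classes is the scorer, because `f1_score` defaults to `average='binary'`. The following is a minimal sketch, assuming a current scikit-learn (`sklearn.model_selection`), a smaller grid and sample count than the question's, and class weights that sum to 1 (`[0.1, 0.6, 0.3]` is an illustrative choice, not taken from the question). It passes `average='macro'` through `make_scorer`; the built-in string `scoring='f1_macro'` should behave the same way.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import f1_score, make_scorer
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Illustrative data: weights chosen to sum to 1 (assumption, differs from the question)
X, y = make_classification(n_samples=600, n_features=10, n_informative=2,
                           n_classes=3, n_clusters_per_class=1,
                           weights=[0.1, 0.6, 0.3], random_state=0)

pipe = make_pipeline(StandardScaler(), SVC(kernel="rbf", class_weight="balanced"))
param_grid = {"svc__C": np.logspace(-1, 1, 3), "svc__gamma": np.logspace(-2, 0, 3)}

# Macro averaging: compute F1 for each class, then take the unweighted mean,
# so the scorer no longer assumes a binary target.
macro_f1 = make_scorer(f1_score, average="macro")

gscv = GridSearchCV(pipe, param_grid, scoring=macro_f1, cv=3)
gscv.fit(X, y)  # y stays a 1-D array of labels 0, 1, 2; no binarization needed
print(gscv.best_params_)
print(round(gscv.best_score_, 3))
```

If larger classes should count for more in the score, `'f1_weighted'` is the analogous built-in option.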
I am trying to one-hot encode the labels, as suggested in Scikit-learn GridSearch giving "ValueError: multiclass format is not supported" error. Also, some datasets, such as the Toxic Comment Classification dataset on Kaggle, already give you binarized labels.
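When a dataset ships one binary column per class and every row has exactly one 1 (true one-hot encoding), the columns can be collapsed back into the 1-D label vector that SVC expects. A minimal sketch on made-up labels, assuming the column order matches `classes`; if a row can contain several 1s, the data is multi-label and `argmax` would silently drop labels, which is why the code checks first.

```python
import numpy as np
from sklearn.preprocessing import label_binarize

classes = np.array([0, 1, 2])
y = np.array([0, 2, 1, 1, 0, 2])  # made-up labels for illustration

# label_binarize turns the 1-D labels into a 2-D 0/1 matrix, one column per class
Y = label_binarize(y, classes=classes)
print(Y.shape)  # (6, 3)

# Collapsing back is only safe when each row has exactly one 1 (one-hot)
assert (Y.sum(axis=1) == 1).all(), "rows with several 1s are multi-label, not one-hot"
y_back = classes[Y.argmax(axis=1)]
print(y_back)   # [0 2 1 1 0 2]
```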
y = label_binarize(y, classes=[0, 1, 2])
for i in classes:
    gscv.fit(X, y[i])
    print(gscv.best_params_)
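If the goal is one binary (one-vs-rest) search per class, each fit needs a 1-D column of the binarized matrix. On a 2-D NumPy array, `y[i]` selects row `i` (three values), while `y[:, i]` selects the 0/1 indicator column for class `i`. A minimal sketch, assuming a plain `SVC`, a small illustrative grid over `C` only, and the built-in binary `'f1'` scorer:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import label_binarize
from sklearn.svm import SVC

X, y = make_classification(n_samples=300, n_features=10, n_informative=2,
                           n_classes=3, n_clusters_per_class=1, random_state=0)
classes = [0, 1, 2]
Y = label_binarize(y, classes=classes)  # shape (300, 3)

best = {}
for i, c in enumerate(classes):
    # Y[:, i] is the 1-D 0/1 column "is this sample class c?"
    gscv = GridSearchCV(SVC(kernel="rbf"), {"C": [0.1, 1, 10]}, scoring="f1", cv=3)
    gscv.fit(X, Y[:, i])
    best[c] = gscv.best_params_
    print(c, gscv.best_params_)
```

Each class can end up with a different best `C`, which is the main difference from searching once on the original 1-D multi-class labels.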
I get:
ValueError: bad input shape (2000L, 3L)
I'm not sure why I am getting this error. My goal is to find the best parameters for a multi-class classification problem.
【Discussion】:
Tags: python scikit-learn svm grid-search
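The shape in the message points at the cause: `SVC.fit` only accepts a 1-D `y`, and a binarized `y` of shape `(n_samples, 3)` is 2-D (the 2000 rows are likely one training split, two thirds of 3000 under 3-fold cross-validation; the exact wording of the error varies by scikit-learn version). When the 2-D matrix is genuinely multi-label, i.e. a row can carry several 1s, one option is to wrap the SVC in `OneVsRestClassifier`, which fits one SVC per column and accepts an indicator matrix. A minimal sketch, assuming synthetic data from `make_multilabel_classification` and an illustrative grid; parameters of the wrapped pipeline are addressed with the `estimator__` prefix.

```python
from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import GridSearchCV
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Synthetic multi-label data: Y is a (300, 3) 0/1 matrix; a row may contain several 1s
X, Y = make_multilabel_classification(n_samples=300, n_features=10,
                                      n_classes=3, random_state=0)

# One SVC per label column, all sharing the same hyperparameters
ovr = OneVsRestClassifier(make_pipeline(StandardScaler(), SVC(kernel="rbf")))
param_grid = {"estimator__svc__C": [0.1, 1, 10],
              "estimator__svc__gamma": [0.01, 0.1]}

# 'f1_macro' averages the per-label F1 scores
gscv = GridSearchCV(ovr, param_grid, scoring="f1_macro", cv=3)
gscv.fit(X, Y)
print(gscv.best_params_)
```

Here a single `(C, gamma)` pair is chosen for all labels; a per-column loop would instead let each label pick its own.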