【问题标题】:Evaluating Logistic regression with cross validation使用交叉验证评估逻辑回归
【发布时间】:2017-01-02 22:51:53
【问题描述】:

我想使用交叉验证来测试/训练我的数据集,并评估逻辑回归模型在整个数据集上的性能,而不仅仅是在测试集上(例如 25%)。

这些概念对我来说是全新的,我不太确定这样做是否正确。如果有人能就我出错的正确步骤向我提供建议,我将不胜感激。我的部分代码如下所示。

另外,如何在与当前图表相同的图表上绘制“y2”和“y3”的 ROC?

谢谢

import pandas as pd 
Data=pd.read_csv ('C:\\Dataset.csv',index_col='SNo')
feature_cols=['A','B','C','D','E']
X=Data[feature_cols]

Y=Data['Status'] 
Y1=Data['Status1']  # predictions from elsewhere
Y2=Data['Status2'] # predictions from elsewhere

from sklearn.linear_model import LogisticRegression
logreg=LogisticRegression()
logreg.fit(X_train,y_train)

from sklearn.cross_validation import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

from sklearn import metrics, cross_validation
predicted = cross_validation.cross_val_predict(logreg, X, y, cv=10)
metrics.accuracy_score(y, predicted) 

from sklearn.cross_validation import cross_val_score
accuracy = cross_val_score(logreg, X, y, cv=10,scoring='accuracy')
print (accuracy)
print (cross_val_score(logreg, X, y, cv=10,scoring='accuracy').mean())

from nltk import ConfusionMatrix 
print (ConfusionMatrix(list(y), list(predicted)))
#print (ConfusionMatrix(list(y), list(yexpert)))

# sensitivity:
print (metrics.recall_score(y, predicted) )

import matplotlib.pyplot as plt 
probs = logreg.predict_proba(X)[:, 1] 
plt.hist(probs) 
plt.show()

# use 0.5 cutoff for predicting 'default' 
import numpy as np 
preds = np.where(probs > 0.5, 1, 0) 
print (ConfusionMatrix(list(y), list(preds)))

# check accuracy, sensitivity, specificity 
print (metrics.accuracy_score(y, predicted)) 

#ROC CURVES and AUC 
# plot ROC curve 
fpr, tpr, thresholds = metrics.roc_curve(y, probs) 
plt.plot(fpr, tpr) 
plt.xlim([0.0, 1.0]) 
plt.ylim([0.0, 1.0]) 
plt.xlabel('False Positive Rate') 
plt.ylabel('True Positive Rate)') 
plt.show()

# calculate AUC 
print (metrics.roc_auc_score(y, probs))

# use AUC as evaluation metric for cross-validation 
from sklearn.cross_validation import cross_val_score 
logreg = LogisticRegression() 
cross_val_score(logreg, X, y, cv=10, scoring='roc_auc').mean() 

【问题讨论】:

    标签: python scikit-learn logistic-regression cross-validation


    【解决方案1】:

    你几乎是对的。 cross_validation.cross_val_predict 为您提供整个数据集的预测。您只需要在代码的前面删除logreg.fit。具体来说,它的作用如下: 它将您的数据集划分为n 折叠,并且在每次迭代中,它将其中一个折叠作为测试集并在其余折叠(n-1 折叠)上训练模型。因此,最终您将获得对整个数据的预测。

    让我们用 sklearn 中的一个内置数据集 iris 来说明这一点。该数据集包含 150 个具有 4 个特征的训练样本。 iris['data']Xiris['target']y

    In [15]: iris['data'].shape
    Out[15]: (150, 4)
    

    要通过交叉验证获得对整个集合的预测,您可以执行以下操作:

    from sklearn.linear_model import LogisticRegression
    from sklearn import metrics, cross_validation
    from sklearn import datasets
    iris = datasets.load_iris()
    predicted = cross_validation.cross_val_predict(LogisticRegression(), iris['data'], iris['target'], cv=10)
    print metrics.accuracy_score(iris['target'], predicted)
    
    Out [1] : 0.9537
    
    print metrics.classification_report(iris['target'], predicted) 
    
    Out [2] :
                         precision    recall  f1-score   support
    
                    0       1.00      1.00      1.00        50
                    1       0.96      0.90      0.93        50
                    2       0.91      0.96      0.93        50
    
          avg / total       0.95      0.95      0.95       150
    

    那么,回到您的代码。你只需要这个:

    from sklearn import metrics, cross_validation
    logreg=LogisticRegression()
    predicted = cross_validation.cross_val_predict(logreg, X, y, cv=10)
    print metrics.accuracy_score(y, predicted)
    print metrics.classification_report(y, predicted) 
    

    要在多类分类中绘制 ROC,您可以关注this tutorial,它会为您提供如下内容:

    总的来说,sklearn 有非常好的教程和文档。我强烈推荐阅读他们的tutorial on cross_validation

    【讨论】:

    • ImportError: 无法从“sklearn”导入名称“cross_validation”。 - scikit-learn.org/stable/modules/… 虽然有 cross_val_score,但没有 cross_validation
    • 使用 'from sklearn.model_selection import cross_val_predict' 来解决这个问题
    猜你喜欢
    • 2019-04-20
    • 1970-01-01
    • 1970-01-01
    • 2017-07-07
    • 2017-01-25
    • 2016-11-20
    • 2017-10-23
    • 2021-06-03
    • 2017-09-13
    相关资源
    最近更新 更多