Posted: 2021-01-09 23:30:03
Question:
I'm using Scikit-Learn's logistic regression to classify digits. The dataset is Scikit-Learn's load_digits.
Here is a simplified version of my code:
```python
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import learning_curve
from sklearn.datasets import load_digits

digits = load_digits()

# Unregularized logistic regression.
# penalty=None needs scikit-learn >= 1.2 (older versions spelled it penalty='none');
# max_iter must be an int, and multi_class='auto' is already the default.
model = LogisticRegression(solver='lbfgs',
                           penalty=None,
                           max_iter=100_000)
model.fit(digits.data, digits.target)

# Confusion matrix computed on the same data the model was trained on
predictions = model.predict(digits.data)
df_cm = pd.DataFrame(confusion_matrix(digits.target, predictions))
ax = sns.heatmap(df_cm, annot=True, cbar=False, cmap='Blues_r', fmt='d', annot_kws={"size": 10})
ax.set_ylim(0, 10)
plt.title("Confusion Matrix")
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()

# Learning curve: 5-fold CV at several training-set sizes, scored by MSE on the labels
train_size = [0.2, 0.4, 0.6, 0.8, 1]
training_size, training_score, validation_score = learning_curve(
    model, digits.data, digits.target, cv=5,
    train_sizes=train_size, scoring='neg_mean_squared_error')
# The scorer returns negated MSE, so flip the sign to plot the error
training_scores_mean = -training_score.mean(axis=1)
validation_score_mean = -validation_score.mean(axis=1)
plt.plot(training_size, validation_score_mean)
plt.plot(training_size, training_scores_mean)
plt.legend(["Validation error", "Training error"])
plt.ylabel("MSE")
plt.xlabel("Training set size")
plt.show()
```
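One thing worth noting about the code above: `neg_mean_squared_error` treats the digit labels 0-9 as numbers, so predicting 7 for a true 1 costs more than predicting 2. A minimal sketch of the same learning curve scored by classification error (1 - accuracy) instead, assuming scikit-learn >= 0.22 and a default L2-penalized `LogisticRegression` with `max_iter=5000` (these settings are illustrative assumptions, not taken from the post):

```python
# Sketch: learning curve scored by classification error instead of label MSE.
# Model settings are illustrative assumptions, not the original poster's.
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import learning_curve

digits = load_digits()
model = LogisticRegression(max_iter=5000)  # defaults: L2 penalty, C=1.0, lbfgs

# With a classifier and an integer cv, learning_curve uses stratified folds
sizes, train_scores, val_scores = learning_curve(
    model, digits.data, digits.target,
    cv=5, train_sizes=[0.2, 0.6, 1.0], scoring="accuracy",
)

# Mean accuracy over the 5 folds, turned into an error rate
train_error = 1 - train_scores.mean(axis=1)
val_error = 1 - val_scores.mean(axis=1)

for n, tr, va in zip(sizes, train_error, val_error):
    print(f"n={int(n):4d}  train error={tr:.3f}  validation error={va:.3f}")
```

Because accuracy does not depend on the numeric distance between labels, the gap between the two curves reads directly as a misclassification rate.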
```python
### EDIT ###
# With L2 regularization (C keeps its default value of 1.0)
model = LogisticRegression(solver='lbfgs',
                           penalty='l2',  # changing the penalty to l2
                           max_iter=100_000)
model.fit(digits.data, digits.target)
predictions = model.predict(digits.data)
df_cm = pd.DataFrame(confusion_matrix(digits.target, predictions))
ax = sns.heatmap(df_cm, annot=True, cbar=False, cmap='Blues_r', fmt='d', annot_kws={"size": 10})
ax.set_ylim(0, 10)
plt.title("Confusion Matrix with L2 regularization")
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()

training_size, training_score, validation_score = learning_curve(
    model, digits.data, digits.target, cv=5,
    train_sizes=train_size, scoring='neg_mean_squared_error')
training_scores_mean = -training_score.mean(axis=1)
validation_score_mean = -validation_score.mean(axis=1)
plt.plot(training_size, validation_score_mean)
plt.plot(training_size, training_scores_mean)
plt.legend(["Validation error", "Training error"])
plt.title("Learning curve with L2 regularization")
plt.ylabel("MSE")
plt.xlabel("Training set size")
plt.show()
```
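`C` in `LogisticRegression` is the inverse of the L2 regularization strength, so smaller values push the weights toward zero. A sketch that makes this visible by printing the norm of `coef_` and the training accuracy for a few values of `C`, assuming default settings apart from `C` and `max_iter=5000` (the chosen `C` values are arbitrary):

```python
# Sketch: effect of the inverse regularization strength C on an L2-penalized model.
import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression

digits = load_digits()
norms = {}

for C in [1e-3, 1e-1, 10]:
    # penalty='l2' is the default; smaller C means stronger regularization
    clf = LogisticRegression(C=C, max_iter=5000)
    clf.fit(digits.data, digits.target)
    norms[C] = np.linalg.norm(clf.coef_)  # Frobenius norm of the (10, 64) weight matrix
    train_acc = clf.score(digits.data, digits.target)
    print(f"C={C:g}  weight norm={norms[C]:.3f}  training accuracy={train_acc:.3f}")
```

With the default `C=1.0`, as in the block above, the penalty is fairly weak for 64 unscaled pixel features, which is consistent with the training error still being close to zero.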
```python
# With L2 regularization and the best C
from sklearn.model_selection import GridSearchCV

C = {'C': [1e-3, 1e-2, 1e-1, 1, 10]}
# The `iid` argument was removed from GridSearchCV in scikit-learn 0.24, so it is dropped here
model_l2 = GridSearchCV(LogisticRegression(random_state=0, solver='lbfgs', penalty='l2', max_iter=100_000),
                        param_grid=C, cv=5, scoring='neg_mean_squared_error')
model_l2.fit(digits.data, digits.target)
best_C = model_l2.best_params_.get("C")
print(best_C)

model_reg = LogisticRegression(solver='lbfgs',
                               penalty='l2',
                               C=best_C,
                               max_iter=100_000)
model_reg.fit(digits.data, digits.target)
predictions = model_reg.predict(digits.data)
df_cm = pd.DataFrame(confusion_matrix(digits.target, predictions))
ax = sns.heatmap(df_cm, annot=True, cbar=False, cmap='Blues_r', fmt='d', annot_kws={"size": 10})
ax.set_ylim(0, 10)
plt.title("Confusion Matrix with L2 regularization and best C")
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()

training_size, training_score, validation_score = learning_curve(
    model_reg, digits.data, digits.target, cv=5,
    train_sizes=train_size, scoring='neg_mean_squared_error')
training_scores_mean = -training_score.mean(axis=1)
validation_score_mean = -validation_score.mean(axis=1)
plt.plot(training_size, validation_score_mean)
plt.plot(training_size, training_scores_mean)
plt.legend(["Validation error", "Training error"])
plt.title("Learning curve with L2 regularization and best C")
plt.ylabel("MSE")
plt.xlabel("Training set size")
plt.show()
```
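Instead of picking a single `C` and then plotting a learning curve, `sklearn.model_selection.validation_curve` reports training and validation scores for every value in the grid, which puts the train/validation gap for each `C` side by side. A sketch, assuming accuracy as the metric, 5-fold CV and `max_iter=5000` (all assumptions, not from the post):

```python
# Sketch: training vs. validation accuracy across a grid of C values.
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import validation_curve

digits = load_digits()
C_range = [1e-3, 1e-2, 1e-1, 1]

# Result arrays have one row per value of C and one column per CV fold
train_scores, val_scores = validation_curve(
    LogisticRegression(max_iter=5000), digits.data, digits.target,
    param_name="C", param_range=C_range, cv=5, scoring="accuracy",
)

train_mean = train_scores.mean(axis=1)
val_mean = val_scores.mean(axis=1)
for C, tr, va in zip(C_range, train_mean, val_mean):
    print(f"C={C:g}  train acc={tr:.3f}  val acc={va:.3f}  gap={tr - va:.3f}")
```

A small gap at strong regularization that widens as `C` grows is the usual signature of the model trading training fit for variance.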
As the confusion matrix on the training data and the last plot generated with learning_curve show, the error on the training set is always 0.
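The confusion matrix in the post is computed on the same samples the model was fitted on, so zero errors there only show that the model can fit its training set, not how it generalizes. A hedged sketch comparing that matrix with one built from out-of-fold predictions via `cross_val_predict`, where every sample is predicted by a model that never saw it during fitting (the default L2 penalty and `max_iter=5000` are assumptions):

```python
# Sketch: training-set confusion matrix vs. out-of-fold confusion matrix.
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_predict

digits = load_digits()
model = LogisticRegression(max_iter=5000)

# 1) Confusion matrix on the training data itself (what the post plots)
model.fit(digits.data, digits.target)
cm_train = confusion_matrix(digits.target, model.predict(digits.data))

# 2) Out-of-fold predictions: each sample is predicted by a clone of the
#    model fitted on the other 4 folds
oof_pred = cross_val_predict(model, digits.data, digits.target, cv=5)
cm_oof = confusion_matrix(digits.target, oof_pred)

# Misclassifications are everything off the diagonal
train_errors = cm_train.sum() - cm_train.trace()
oof_errors = cm_oof.sum() - cm_oof.trace()
print("misclassified (training data):", train_errors)
print("misclassified (out-of-fold):  ", oof_errors)
```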
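It looks to me like the model is severely overfitting, and I can't make sense of it. I also tried the MNIST dataset, and the same thing happened.
How can I fix this?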
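-- EDIT --
I added L2 regularization to the code above, along with the best value for the hyperparameter C.
With L2 regularization, the model still overfits the data: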
[Figure: Learning curve with L2 regularization]
With the best C hyperparameter, the error on the training data is no longer zero, but the algorithm still overfits:
[Figure: Learning curve with L2 regularization and best C]
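One way to check whether the remaining gap is real overfitting is to hold out a test set that the grid search never touches. `GridSearchCV` with its default `refit=True` already retrains the best model on the whole training split and exposes it as `best_estimator_`, so the separate `model_reg` does not need to be built by hand. A sketch, assuming a stratified 75/25 split with `random_state=0`, accuracy scoring and `max_iter=5000` (all assumptions, not from the post):

```python
# Sketch: grid search on a training split, then evaluation on untouched data.
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, train_test_split

digits = load_digits()

# Hold out 25% of the data; the grid search never sees it
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target,
    test_size=0.25, stratify=digits.target, random_state=0,
)

search = GridSearchCV(
    LogisticRegression(max_iter=5000),
    param_grid={"C": [1e-3, 1e-2, 1e-1, 1, 10]},
    cv=5, scoring="accuracy",
)
# refit=True (the default) retrains the best C on all of X_train
search.fit(X_train, y_train)

best_model = search.best_estimator_
train_acc = best_model.score(X_train, y_train)
test_acc = best_model.score(X_test, y_test)
print("best C:", search.best_params_["C"])
print(f"training accuracy: {train_acc:.3f}  held-out accuracy: {test_acc:.3f}")
```

Some gap between training and held-out accuracy is expected for any fitted model; what matters is whether the held-out accuracy is acceptable.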
I still don't understand what's going on...
Tags: machine-learning scikit-learn classification logistic-regression