【问题标题】:Unexpected behavior of ML Logistic Regression?ML Logistic 回归的意外行为?
【发布时间】:2022-01-25 17:33:32
【问题描述】:

我是 AI 和 ML 的新手,如果这是一个愚蠢的问题,我深表歉意。
我在阅读 Logistic Regression,发现它是一个分类监督的 ML 模型。

所以我尝试编写一个示例来尝试一下。我的想法是看看程序是否能够找出我建立的标签 (Y) 背后的“规则”,即“Y = 1 当且仅当 X1 OR X2 是 3 的倍数但不是两者兼有时, 否则为 0"

但正如您所见,准确性非常差。难道我做错了什么?我是否误解了逻辑回归的概念?

数据集

3,1,1
2,3,1
1,1,0
2,4,0
5,6,1
9,3,1
8,9,1
5,5,0
9,9,0
5,7,0
3,3,0
5,3,1
2,4,0
7,7,0
4,9,1
7,3,1
6,2,1
8,1,0
6,4,0
9,4,1

代码

from sklearn.linear_model import LogisticRegression
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn import metrics
col_names = ['x1', 'x2', 'y']
multi3 = pd.read_csv("1.csv", header=None, names=col_names)
feature_cols = ['x1', 'x2']
X = multi3[feature_cols]
y = multi3.y
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)
logreg = LogisticRegression()
logreg.fit(X_train, y_train)
y_pred = logreg.predict(X_test)
cnf_matrix = metrics.confusion_matrix(y_test, y_pred)
print(cnf_matrix)
print("Accuracy:", metrics.accuracy_score(y_test, y_pred))
print("Precision:", metrics.precision_score(y_test, y_pred))
print("Recall:", metrics.recall_score(y_test, y_pred))

输出

[[1 2]
 [1 1]]
Accuracy: 0.4
Precision: 0.3333333333333333
Recall: 0.5

编辑

下面我评论的源代码。

【问题讨论】:

  • 有趣的是,当我将规则更改为“如果 X1>5 或 X2>5 时 Y=1,否则为 0”我得到了 92% 的准确率(上面的屏幕截图)
  • 您还将观察次数从 20 更改为 1000。

标签: python machine-learning scikit-learn artificial-intelligence logistic-regression


【解决方案1】:

您可以可视化您的数据:

multi3.plot.scatter(x = "x1",y="x2", c = "y",cmap="viridis")

您可以看到您的两个不同类别(0 或 1)之间没有明显的区别。因此,即使使用较小的测试集,您获得的准确度也会很低,因为 x1 和 x2 在区分标签方面根本没有用。

在您发布的代码中,您在更大的数据集和模拟数据上进行了处理,如果我们做类似的事情,

import numpy as np
np.random.seed(123)
df = pd.DataFrame(np.random.randint(0,10,(60,2)),columns=['x1', 'x2'])
df['y'] = ((df['x1']>5) & (df['x2'] > 5)).astype(int)

logreg = LogisticRegression()
logreg.fit(df[['x1','x2']], df['y'])
y_pred = logreg.predict(df[['x1','x2']])
cnf_matrix = metrics.confusion_matrix(df['y'], y_pred)
cnf_matrix

array([[49,  2],
       [ 2,  7]])

当然,你可以看到有分离:

我的猜测是原始数据集是错误的,或者与您发布的图片无关。

【讨论】:

    相关资源