【Posted】: 2020-06-21 04:27:37
【Question】:
I ran this code, but the lr.fit line raises an error. Does anyone know how to fix it?
from sklearn.model_selection import cross_val_predict
from sklearn.model_selection import cross_val_score
from sklearn import linear_model
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import pandas as pd
df = pd.read_csv('2019.csv')
df1 = pd.DataFrame(df,columns=['GDP per capita', 'Social support'])
lr = LogisticRegression()
columns = ['GDP per capita', 'Social support']
X = df[columns]
y = df["Score"]
X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.20,random_state=0)
lr.fit(X_train,y_train)
predictions = lr.predict(X_test)
accuracy = accuracy_score(y_test, predictions)
print(accuracy)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-5-afa10dbaa367> in <module>
19 X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.30,random_state=0)
20
---> 21 lr.fit(X_train,y_train)
22 predictions = lr.predict(X_test)
23 accuracy = accuracy_score(y_test, predictions)
~/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/_logistic.py in fit(self, X, y, sample_weight)
1526 X, y = check_X_y(X, y, accept_sparse='csr', dtype=_dtype, order="C",
1527 accept_large_sparse=solver != 'liblinear')
-> 1528 check_classification_targets(y)
1529 self.classes_ = np.unique(y)
1530 n_samples, n_features = X.shape
~/opt/anaconda3/lib/python3.7/site-packages/sklearn/utils/multiclass.py in check_classification_targets(y)
167 if y_type not in ['binary', 'multiclass', 'multiclass-multioutput',
168 'multilabel-indicator', 'multilabel-sequences']:
--> 169 raise ValueError("Unknown label type: %r" % y_type)
170
171
ValueError: Unknown label type: 'continuous'
Above is the full error traceback. The only way I got it to work was by calling .astype(int) on both X and y; without that, I get the error shown.
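For context, the message comes from scikit-learn's target-type check. The following is a minimal sketch, using made-up values rather than the actual 2019.csv data (an assumption), of how `sklearn.utils.multiclass.type_of_target` labels float versus integer targets, which would explain why `.astype(int)` lets `fit` run:

```python
import numpy as np
from sklearn.utils.multiclass import type_of_target

# Hypothetical scores standing in for df["Score"]; not taken from 2019.csv.
scores = np.array([7.769, 7.600, 5.432, 4.812, 3.203])

# Floats with fractional parts are treated as a regression-style target.
float_kind = type_of_target(scores)

# Truncating to integers (7, 7, 5, 4, 3) makes them look like class labels.
int_kind = type_of_target(scores.astype(int))

print(float_kind)
print(int_kind)
```

Truncating to integers turns the score into a handful of classes, so the accuracy computed afterwards measures something different from predicting the original score.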
【Comments】:
-
Where is the traceback?
-
It says I can't use continuous variables, but most of my column values are continuous, so I'm a bit stuck.
-
Can you share the full traceback?
-
Narrow down exactly where the error occurs. Where is the traceback, or the data itself?
-
Hi everyone, I've added the error in an edit. Sorry, @UchihaAJ.
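Since "Score" is continuous, one possible direction is to treat this as a regression problem rather than classification. The sketch below swaps in `LinearRegression` and regression metrics. It is not the asker's confirmed fix, and the DataFrame is a synthetic stand-in for 2019.csv (an assumption) that only reuses the column names from the question:

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split

# Synthetic stand-in for 2019.csv (assumption): same column names, made-up values.
rng = np.random.default_rng(0)
n = 150
df = pd.DataFrame({
    "GDP per capita": rng.uniform(0.0, 1.7, n),
    "Social support": rng.uniform(0.0, 1.6, n),
})
df["Score"] = (
    3.0
    + 1.5 * df["GDP per capita"]
    + 1.0 * df["Social support"]
    + rng.normal(0.0, 0.3, n)
)

X = df[["GDP per capita", "Social support"]]
y = df["Score"]  # left as floats; no .astype(int) needed for a regressor
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, random_state=0
)

reg = LinearRegression()
reg.fit(X_train, y_train)
predictions = reg.predict(X_test)

# accuracy_score expects class labels, so use regression metrics instead.
r2 = r2_score(y_test, predictions)
mae = mean_absolute_error(y_test, predictions)
print(f"R^2: {r2:.3f}, MAE: {mae:.3f}")
```

If a classifier is really what is wanted, the score would first have to be binned into explicit categories (for example with `pd.cut`), which is a modelling decision rather than a type conversion.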
Tags: python pandas numpy matplotlib logistic-regression