【发布时间】:2018-06-10 22:36:12
【问题描述】:
所以我是机器学习的新手,对这个错误有点困惑:
形状 (1,4) 和 (14,14) 未对齐:4 (dim 1) != 14 (dim 0)
这是完整的错误:
文件“/Users/jim/anaconda3/lib/python3.6/site-packages/sklearn/utils/extmath.py”,第 140 行,位于 safe_sparse_dot 返回 np.dot(a, b)
ValueError:形状 (1,4) 和 (14,14) 未对齐:4 (dim 1) != 14 (dim 0)
我的测试集有 4 行数据,训练集有 14 行数据,如 (1,4) 和 (14,14) 所示。至少我认为是这个意思。
我正在尝试将简单的线性回归拟合到训练集,如下面的代码所示:
# Fit Simple Linear Regression to Training Set
from sklearn.linear_model import LinearRegression
regressor = LinearRegression()
X_train = X_train.reshape(1,-1)
y_train = y_train.reshape(1,-1)
regressor.fit(X_train, y_train)
然后预测测试集结果:
# Predicting the Test Set Results
X_test = X_test.reshape(1,-1)
y_pred = regressor.predict(X_test)
我的代码在最后一行出现上述错误:
y_pred = regressor.predict(X_test)
任何正确方向的提示都会很棒。
这是我的整个代码示例:
# Simple Linear Regression
# Import Libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# Import dataset
dataset = pd.read_csv('NBA.csv')
X = dataset.iloc[:, 1].values
y = dataset.iloc[:, :-1].values
# Splitting the dataset into Train and Test
from sklearn.cross_validation import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 0)
# Feature Scaling
# None
# Fit Simple Linear Regression to Training Set
from sklearn.linear_model import LinearRegression
regressor = LinearRegression()
X_train = X_train.reshape(1,-1)
y_train = y_train.reshape(1,-1)
regressor.fit(X_train, y_train)
# Predicting the Test Set Results
X_test = X_test.reshape(1,-1)
y_pred = regressor.predict(X_test)
** 编辑 ** 我检查了 X 和 y 的形状。下面是我的输出:
dataset = pd.read_csv('NBA.csv')
X = dataset.iloc[:, 1].values
y = dataset.iloc[:, :-1].values
print(X.shape)
print(y.shape)
-->(18,)
-->(18, 1)
【问题讨论】:
-
reshape(1,-1)似乎很危险,因为它会转换“1 个样本”x“n 个特征”数组。你能展示X_train、X_test等的原始形状吗? -
@dkato 谢谢。我已经用我的原始代码更新了我的帖子。你想让我也发布我的数据吗?
-
谢谢。我想在从
datasetP数据框制作切片后,通过执行X.shape和y.shape检查X和y的形状。也许调试就足够了。 -
@dkato 好的,X 和 y 的形状分别是 (18,) 和 (18, 1)。另外,我在原始帖子的底部对其进行了更新。
-
那么在这种情况下,您是否尝试使用每个样本只有一个特征的 18 个样本来训练回归模型?
标签: python python-3.x machine-learning scikit-learn