【问题标题】:svm confusion_matrix 'too many values to unpack'支持向量机混淆矩阵'解包的值太多'
【发布时间】:2019-07-12 13:06:58
【问题描述】:

我已经训练了一个 SVM 模型并尝试创建一个混淆矩阵来评估它。 因此,我用测试数据进行预测,并将预测与测试数据的目标类进行比较。

我有大约1000条数据记录,而Test数据是近300条数据记录。 我定义了九个类/标签。

特征从-1到1归一化,都是float类型。 数组A的一行代表每条数据记录,目标类存放在数组B中。 我将这些数组按 70:30 的比例分成训练和测试数据。

这是一个简单的代码,但我现在不知道。 一种可能性是对测试数据的每个数据记录进行预测和混淆矩阵,并将结果存储在列表中。遍历所有数据记录后,我可以建立所有存储元素的平均值吗?

有人知道如何解决我的问题吗?

# -*- coding: utf-8 -*-
"""
Created on Fri Apr  5 10:50:47 2019

@author: mattdoe
"""

from data_preprocessor_db import data_storage # validation data
from sklearn.preprocessing import MinMaxScaler
from sklearn import svm
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from numpy import array
import pickle



# for seperation of data_storage
# Link_ID = list()
Input, Output = list(), list()

# seperate data_storage in Input and Output data
for items in data_storage:
    # Link_ID = items[0] # identifier not needed
    Input.append((float(items[1]), float(items[2]), float(items[3]), float(items[4]), float(items[5]), float(items[6]), float(items[7]), float(items[8]), float(items[9]))) # Input: all characteristics
    Output.append(float(items[10])) # Output: scenario_class 1 to 9

# Input tuple to array
A = array(Input)

# normalise array between 0 and 1
scaler = MinMaxScaler(feature_range=(-1, 1))
scaledA = scaler.fit_transform(A)

# Output tuple to array
B = array(Output)

# split train and test data; ration: 70:30
# shuffle = False: doesn't sort data randomly
# shuffle = True: default: sorts data randomly
A_train, A_test, B_train, B_test = train_test_split(A, B, test_size=0.3, shuffle=True, random_state=40)

# create model
model = svm.SVC(kernel='linear', C = 1.0)

# fit model
model.fit(A_train, B_train)

# get support vectors
# model.support_vectors_

# get indices of support vectors
# model.support_

# get number of support vectors for each class
# model.n_support_

filename = 'ml_svm.sav'
pickle.dump(model, open(filename, 'wb'))

# load the model from disk
loaded_model = pickle.load(open(filename, 'rb'))

# test to all data records
# result = loaded_model.score(A, B)

# test with test data
# score represents the mean accuracy of given test data and labels
result = loaded_model.score(A_test, B_test) # relative 
print(result)

# confusion matrix compares true value with predicted value
# true value <--> predicted value
predicted = model.predict(A_test)
tn, fp, fn, tp = confusion_matrix(B_test, predicted, labels=[1, 2, 3, 4, 5, 6, 7, 8, 9]).ravel()

我的错误:

Traceback (most recent call last):

  File "<ipython-input-8-8649dd873bbd>", line 1, in <module>
    runfile('C:/Workspace/Master-Thesis/Programm/MapValidationML/ml_svm.py', wdir='C:/Workspace/Master-Thesis/Programm/MapValidationML')

  File "C:\ProgramData\Anaconda3\lib\site-packages\spyder_kernels\customize\spydercustomize.py", line 827, in runfile
    execfile(filename, namespace)

  File "C:\ProgramData\Anaconda3\lib\site-packages\spyder_kernels\customize\spydercustomize.py", line 110, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)

  File "C:/Workspace/Master-Thesis/Programm/MapValidationML/ml_svm.py", line 75, in <module>
    tn, fp, fn, tp = confusion_matrix(B_test, predicted, labels=[1, 2, 3, 4, 5, 6, 7, 8, 9]).ravel()

ValueError: too many values to unpack (expected 4)

【问题讨论】:

  • confusion_matrix 正在返回矩阵。使用 ravel 将矩阵重塑为向量。向量的维度取决于矩阵大小(所以在他们为binary case 显示的文档中,它们有4个元素),但在你的情况下,矩阵维度取决于我猜的标签(很可能是NxN标签)。因此,如果您事先不知道矩阵大小,可能会将结果存储在向量中(所以 confusion_vector = confusion_matrix(.. 而不是 tn, fp, fn, tp = confusion_matrix(..
  • 我认为你应该把它放在答案@elgordorafiki

标签: python machine-learning scikit-learn svm


【解决方案1】:

感谢 elgordorafiki。 使用混淆向量 = 混淆矩阵(...) 的解决方案效果很好。

没有 .ravel(),我现在收到一个 9x9 矩阵。

对角线上的结果是否都是正确的值,而对角线上的结果是否都是不正确的?那么每一列每一行代表一个类?哪些是预测类?列还是行?

我必须如何理解结果?

我的结果如下:

[[ 35   1   0   0   0   0   0   0   0]
 [  0 177   0   0   0   0   0   0   0]
 [  3   2   0   0   0   0   0   0   0]
 [  2   3   0   0   0   0   0   0   0]
 [  0   0   0   0   5   0   0   0   0]
 [  0   0   0   0   0   8   0   0   0]
 [  0   0   0   0   0   0   3   0   0]
 [  0   0   0   0   0   0   0   7   0]
 [  4   6   0   0   1   1   1   0  14]]

在我的情况下,3 级和 4 级似乎与 1 级和 2 级有问题。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2019-09-14
    • 2018-04-07
    • 2019-02-12
    • 1970-01-01
    • 2018-09-18
    • 1970-01-01
    • 2018-10-05
    • 2018-11-19
    相关资源
    最近更新 更多