【问题标题】:Support Vector Machine: Python Error Message支持向量机:Python 错误消息
【发布时间】:2019-01-26 00:28:12
【问题描述】:

我对机器学习比较陌生,我决定深入研究一些理论,然后用一些代码进行练习。在此过程中,我收到了许多我设法修复的错误消息,但我对此一无所知。我对 Python 也比较陌生,所以我确定这是一些与语法相关的问题,但这次我无法确定它(Python 2.7.15)。这是完整的代码:

import numpy as np
from matplotlib import pyplot as plt

# Next we input our data of the for [X, Y, Bias] in a matrix using the Numpy array method:

X = np.array([ 
    [-2, 4,-1], 
    [2, -2, -1],
    [2, 4, -1],
    [8,-4, -1],
    [9, 4, -1],
])

# Let's make another variable Y that contains the output labels for each element in the matrix:

Y = np.array([-1,-1,1,1,1])

#Now let's plot our data. We're going to use a For Loop for this:

for index,element in enumerate(X):
    if index<2:
        plt.scatter(element[0],element[1], marker="_", s=120, color="r")
    else:
        plt.scatter(element[0],element[1], marker="+", s=120, color="b")

plt.plot([-2,8], [8,0.5])
plt.show()


def svm_sgd_plot(X, Y):
    #Initialize our SVMs weight vector with zeros (3 values)
    w = np.zeros(len(X[0]))
    #The learning rate
    eta = 1
    #how many iterations to train for
    epochs = 100000
    #store misclassifications so we can plot how they change over time
    errors = []

    #training part & gradient descent part
    for epoch in range(1,epochs):
        error = 0
        for i, x in enumerate(X):
            #misclassification
            if (Y[i]*np.dot(X[i], w)) < 1:
                #misclassified update for ours weights
                w = w + eta * ( (X[i] * Y[i]) + (-2  * (1/epoch) * w) )
                error = 1
            else:
                #correct classification, update our weights
                w = w + eta * (-2  * (1/epoch) * w)
    errors.append(error)

    # lets plot the rate of classification errors during training for our SVM

    plt.plot(errors, '|')
    plt.ylim(0.5,1.5)
    plt.axes().set_yticklabels([])
    plt.xlabel('Epoch')
    plt.ylabel('Misclassified')
    plt.show()

    return w

for d, sample in enumerate(X):
        # Plot the negative samples
    if d < 2:
        plt.scatter(sample[0], sample[1], s=120, marker='_', linewidths=2)
        # Plot the positive samples
    else:
        plt.scatter(sample[0], sample[1], s=120, marker='+', linewidths=2)

# Add our test samples
plt.scatter(2,2, s=120, marker='_', linewidths=2, color='yellow')
plt.scatter(4,3, s=120, marker='+', linewidths=2, color='blue')
plt.show()

# Print the hyperplane calculated by svm_sgd()
x2=[ w[0],w[1],-w[1],w[0] ]
x3=[ w[0],w[1],w[1],-w[0] ]

x2x3 = np.array([x2,x3])
X,Y,U,V = zip(*x2x3)
ax = plt.gca()
ax.quiver(X,Y,U,V,scale=1, color='blue')

w = svm_sgd_plot(X,Y)

但我不断收到以下错误:

Traceback(最近一次调用最后一次):文件“C:\Users...\Support Vector 机器(从头开始).py”,第 134 行,在 x2=[ w[0],w[1],-w[1],w[0] ] NameError: name 'w' is not defined

我希望有更多知识的人能提供帮助。谢谢。

【问题讨论】:

  • 我喜欢你的头像。当您尚未定义变量时,还会出现 NameError。查看您的代码以找到您定义 w 的位置。
  • @Scott 哈哈谢谢。好吧,我以为我在定义函数 svm_sgd_plot(): as "w = np.zeros(len(X[0]))" 时已经定义了它,但显然这不起作用,这让我很困惑,因为它以前一直有效在其他程序中。
  • 你希望这条线做什么?你还没有定义任何名为 w yer 的东西。看起来您最终创建的w 的定义取决于这些x2x3 值,因此它们不能依赖于这些值。也许您打算在这里使用其他变量而不是w?但如果是这样,我不知道是什么变量——但大概你知道。
  • 附注,如果你在 2018 年刚刚学习 Python,为什么要学习 2.7 而不是 3.7?

标签: python python-2.7 machine-learning svm


【解决方案1】:

首先,您在方法svm_sgd_plot 中定义了w,但是在您明确调用它来做某事之前,该方法不会做任何事情。

您可以通过添加w = svm_sgd_plot(X,Y) 行来调用它,例如在绘制您的测试数据之后,这样您的代码就变成了

#PLOT TRAINING DATA

for d, sample in enumerate(X):
    # Plot the negative samples
if d < 2:
    plt.scatter(sample[0], sample[1], s=120, marker='_', linewidths=2)
    # Plot the positive samples
else:
    plt.scatter(sample[0], sample[1], s=120, marker='+', linewidths=2)

#PLOT TESTING DATA

# Add our test samples
plt.scatter(2,2, s=120, marker='_', linewidths=2, color='yellow')
plt.scatter(4,3, s=120, marker='+', linewidths=2, color='blue')
plt.show()

#CALL YOUR METHOD
w = svm_sgd_plot(X,Y)

然后你只需要可视化你的方法提供的分类。我添加了您的两个测试数据观察结果,以便您可以看到您的 SVM 方法如何正确分类它们。请注意,黄点和蓝点由您的 SVM 方法生成的线隔开。

# Print the hyperplane calculated by svm_sgd()
x2=[ w[0],w[1],-w[1],w[0] ]
x3=[ w[0],w[1],w[1],-w[0] ]

x2x3 = np.array([x2,x3])
X,Y,U,V = zip(*x2x3)
ax = plt.gca()
ax.quiver(X,Y,U,V,scale=1, color='blue')
#I ADDED THE FOLLOWING THREE LINES SO THAT YOU CAN SEE HOW YOU TESTING DATA IS BEING CLASSIFIED BY YOUR SVM METHOD
plt.scatter(2,2, s=120, marker='_', linewidths=2, color='yellow')
plt.scatter(4,3, s=120, marker='+', linewidths=2, color='blue')
plt.show()

【讨论】:

  • 哦,现在我明白了。谢谢!
猜你喜欢
  • 2017-05-30
  • 2018-06-15
  • 1970-01-01
  • 1970-01-01
  • 2019-01-27
  • 1970-01-01
  • 2014-05-12
  • 2019-09-19
  • 2012-01-31
相关资源
最近更新 更多