一,介绍
算法主要步骤包括:初始化原型向量;迭代优化,更新原型向量。
流程如下:
具体来说,主要是:
1、对原型向量初始化,可以选择满足yj=tj,j∈{1,2,…,m}yj=tj,j∈{1,2,…,m}条件的某个样本 xj=(xj1,xj2,…,xjn)xj=(xj1,xj2,…,xjn)作为 qjqj的初始值;
2、从数据集DD 中任意选择一个样本 xjxj,找到与此样本距离最近的原型向量,假设为qiqi ;
3、如果xjxj的标记yjyj 与qiqi的标记titi相等,即 yj=ti,yj=ti,则令:
否则:
4、更新原型向量:
5、判断是否达到最大迭代次数或者原型向量更新幅度小于某个阈值。如果是,则停止迭代,输出原型向量;否则,转至步骤2。
其中步骤3和4的物理意义是:如果xjxj和最近的原型向量qiqi具有同样的类别标记,则令 qiqi向 xjxj的方向靠拢,且:
否则,qiqi 远离 xjxj,且
二,代码实现
import matplotlib.pyplot as plt
import numpy as np
import math
import random
def loadDataSet(filename):
fr = open(filename)
numberOfLines = len(fr.readlines())
returnMat = np.zeros((numberOfLines, 2))
classLabelVector = []
fr = open(filename)
index = 0
for line in fr.readlines():
line = line.strip().split(',')
returnMat[index, :] = line[0:2]
classLabelVector.append(line[-1])
index += 1
return returnMat, classLabelVector
# 欧几里得距离
def edistance(v1, v2):
result=0.0
for i in range(len(v1)):
result +=(v1[i]-v2[i])**2
return math.sqrt(result)
# 学习向量量化算法
def lvq(dataMat, labelMat,alpha=0.1,times=500):
classify = set(labelMat)
randinfo = [random.randint(0,14),random.randint(15,30)]
clusters = [dataMat[randinfo[i]] for i in range(len(randinfo))] # 随机选取k个值作为聚类中心
while times > 0: # 迭代次数
n = random.randint(0,29)
d=np.array([edistance(clusters[i], dataMat[n]) for i in range(len(clusters))],dtype='float') # 获取和各个聚类中心距离
index = np.argmin(d)
if(labelMat[n]==labelMat[randinfo[index]]): # 同类靠近
clusters[index]=clusters[index]+alpha*(dataMat[n]-clusters[index])
print("同类:",alpha*(dataMat[n]-clusters[index]))
else: # 异类远离
clusters[index] = clusters[index] - alpha * (dataMat[n] - clusters[index])
print("异类:", alpha * (dataMat[n] - clusters[index]))
times-=1
print("中心点:%s",(clusters))
return clusters
def plot(dataMat, labelMat,clusters):
xcord = [];ycord = []
sumx1 = 0.0;sumy1 = 0.0;sumx2 = 0.0;sumy2 = 0.0
midx = [];midy = []
for i in range(len(dataMat)):
xcord.append(float(dataMat[i][0]));ycord.append(float(dataMat[i][1]))
for i in range(len(labelMat)):
if(labelMat[i]=="1"):
plt.scatter(xcord[i], ycord[i], color='red')
else:
plt.scatter(xcord[i], ycord[i], color='black')
for c in clusters:
plt.scatter(c[0], c[1], marker='+', color='blue')
for j in range(len(labelMat)):
if (labelMat[j] == "1"):
sumx1+=xcord[j]
sumy1+=ycord[j]
else:
sumx2 += xcord[j]
sumy2 += ycord[j]
midx.append(sumx1 / 17)
midx.append(sumx2 / 17)
midy.append(sumy1 / 13)
midy.append(sumy2 / 13)
plt.scatter(midx, midy, marker='*',color='green')
plt.show()
if __name__=='__main__':
dataMat, labelMat = loadDataSet('watermelon4.1.txt')
clusters = lvq(dataMat, labelMat)
plot(dataMat, labelMat,clusters)
结果如下: