# -*- coding: utf-8 -*-
"""
Created on Thu Nov 29 09:56:30 2018
@author: muli
"""
from sklearn.datasets.samples_generator import make_blobs
from sklearn import mixture
from sklearn.metrics import adjusted_rand_score
import matplotlib.pyplot as plt
import numpy as np
def create_data(centers,num=100,std=0.7):
'''
生成用于聚类的数据集
:param centers: 聚类的中心点组成的数组。如果中心点是二维的,则产生的每个样本都是二维的。
:param num: 样本数
:param std: 每个簇中样本的标准差
:return: 用于聚类的数据集。是一个元组,第一个元素为样本集,第二个元素为样本集的真实簇分类标记
'''
X, labels_true = make_blobs(n_samples=num, centers=centers, cluster_std=std)
return X,labels_true
def test_GMM(*data):
'''
测试 GMM 的用法
:param data: 可变参数。它是一个元组。元组元素依次为:第一个元素为样本集,第二个元素为样本集的真实簇分类标记
:return: None
'''
X,labels_true=data
clst=mixture.GaussianMixture()
clst.fit(X)
predicted_labels=clst.predict(X)
print("ARI:%s"% adjusted_rand_score(labels_true,predicted_labels))
def test_GMM_n_components(*data):
'''
测试 GMM 的聚类结果随 n_components 参数的影响
:param data: 可变参数。它是一个元组。元组元素依次为:第一个元素为样本集,第二个元素为样本集的真实簇分类标记
:return: None
'''
X,labels_true=data
nums=range(1,50)
ARIs=[]
for num in nums:
clst=mixture.GaussianMixture(n_components=num)
clst.fit(X)
predicted_labels=clst.predict(X)
ARIs.append(adjusted_rand_score(labels_true,predicted_labels))
## 绘图
fig=plt.figure()
ax=fig.add_subplot(1,1,1)
ax.plot(nums,ARIs,marker="+")
ax.set_xlabel("n_components")
ax.set_ylabel("ARI")
fig.suptitle("GMM")
# 设置 x 轴的刻度大小
plt.xticks(np.arange(1,50,2))
# 设置 X 轴的网格线,风格为 点画线
plt.grid(axis='x',linestyle='-.')
plt.show()
def test_GMM_cov_type(*data):
'''
测试 GMM 的聚类结果随协方差类型的影响
:param data: 可变参数。它是一个元组。元组元素依次为:第一个元素为样本集,第二个元素为样本集的真实簇分类标记
:return: None
'''
X,labels_true=data
nums=range(1,50)
cov_types=['spherical','tied','diag','full']
markers="+o*s"
fig=plt.figure()
ax=fig.add_subplot(1,1,1)
for i ,cov_type in enumerate(cov_types):
ARIs=[]
for num in nums:
clst=mixture.GaussianMixture(n_components=num,covariance_type=cov_type)
clst.fit(X)
predicted_labels=clst.predict(X)
ARIs.append(adjusted_rand_score(labels_true,predicted_labels))
ax.plot(nums,ARIs,marker=markers[i],label="covariance_type:%s"%cov_type)
ax.set_xlabel("n_components")
ax.legend(loc="best")
ax.set_ylabel("ARI")
fig.suptitle("GMM")
# 设置 x 轴的刻度大小
plt.xticks(np.arange(1,50,2))
# 设置 X 轴的网格线,风格为 点画线
plt.grid(axis='x',linestyle='-.')
plt.show()
if __name__=='__main__':
# 用于产生聚类的中心点
# 聚类中心是几维,则特征向量是几维的
centers=[[1,1],[2,2],[1,2],[10,20]]
# 产生用于聚类的数据集
X,labels_true=create_data(centers,1000,0.5)
# test_GMM(X,labels_true) # 调用 test_GMM 函数
# test_GMM_n_components(X,labels_true) # 调用 test_GMM_n_components 函数
test_GMM_cov_type(X,labels_true) # 调用 test_GMM_cov_type 函数
- 如图