【问题标题】:How to get Mean and Covariance value from pomegranate Gaussian Mixture model如何从石榴高斯混合模型中获取均值和协方差值
【发布时间】:2021-02-26 19:37:27
【问题描述】:

在 scikit 学习高斯混合模型中,我们可以通过以下方式获得均值和协方差

clf = GaussianMixture(n_components=num_clusters, covariance_type="tied", init_params='kmeans')
for i in range(clf.n_components):
    cov=clf.covariances_[i]
    mean=clf.means_[i]

但在石榴高斯混合模型的情况下,没有称为“协方差_”和“均值_”的属性 非常感谢您宝贵的时间。

【问题讨论】:

    标签: python mixture-model gmm pomegranate


    【解决方案1】:

    当您运行covariance_type="tied" 时,模型假定所有组件都有一个共同的协方差矩阵,因此上面的代码不成立。如果 covariance_type="tied" 那么它将是 clf.covariances_ 下的 1 个协方差矩阵。参考help page

    ‘full’每个分量都有自己的通用协方差矩阵

    “绑定”所有组件共享相同的通用协方差矩阵

    使用pomegranate,它估计每个组件的协方差矩阵,因此可以很好地比较从sklearn 运行GaussianMixturecovariance_type="full"

    from sklearn import datasets
    from sklearn.mixture import GaussianMixture
    
    iris = datasets.load_iris()
    
    clf = GaussianMixture(n_components=3, covariance_type="full", init_params='kmeans')
    clf.fit(iris.data)
    cov = []
    means = []
    for i in range(clf.n_components):
        cov.append(clf.covariances_[i])
        means.append(clf.means_[i])
    

    所以对于组件或集群 0:

    means[0]
    
    array([5.006, 3.428, 1.462, 0.246])
    
    cov[0]
    
    array([[0.121765, 0.097232, 0.016028, 0.010124],
           [0.097232, 0.140817, 0.011464, 0.009112],
           [0.016028, 0.011464, 0.029557, 0.005948],
           [0.010124, 0.009112, 0.005948, 0.010885]])
    

    现在使用石榴:

    from pomegranate import GeneralMixtureModel, MultivariateGaussianDistribution
    
    mdl = GeneralMixtureModel.from_samples(MultivariateGaussianDistribution,
                                           n_components=3, X=iris.data)
    mdl = mdl.fit(iris.data)
    

    参数可以在distributions下访问,只要你的组件都有一个列表。第一个,你做distributions[0],第二个distributions[1]等等:

    mdl.distributions[0].parameters[0]
    
    [5.005999999999999, 3.4280000000000004, 1.462, 0.24599999999999986]
    
    np.round(mdl.distributions[0].parameters[1],6)
    
    array([[0.121764, 0.097232, 0.016028, 0.010124],
           [0.097232, 0.140816, 0.011464, 0.009112],
           [0.016028, 0.011464, 0.029556, 0.005948],
           [0.010124, 0.009112, 0.005948, 0.010884]])
    

    【讨论】:

    猜你喜欢
    • 2014-10-25
    • 1970-01-01
    • 2018-05-07
    • 2021-10-12
    • 2016-07-25
    • 2022-01-21
    • 1970-01-01
    • 2021-09-30
    • 1970-01-01
    相关资源
    最近更新 更多