gmm.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:lazyprogrammer 作者: inhwane 项目源码 文件源码
def main():
    pi = np.array([0.3, 0.5, 0.2])
    mu = np.array([[1,1], [-1,-1], [-1,1]])*3
    sigma = np.array([
        [[1,0], [0,1]],
        [[2,0], [0,2]],
        [[0.5,0], [0, 0.5]],
    ])
    X, C = generate_data(pi, mu, sigma, 1000)
    plt.scatter(X[:,0], X[:,1], c=C, s=100, alpha=0.5)
    plt.show()


    # sklearn
    gmm = GMM(n_components=3, covariance_type='full')
    gmm.fit(X)
    print "pi:", gmm.weights_
    print "mu:", gmm.means_
    print "sigma:", gmm.covars_

    pi2, mu2, sigma2, L = expectation_maximization(X, len(pi))
    print "pi:", pi2
    print "mu:", mu2
    print "sigma:", sigma2
    plt.plot(L)
    plt.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号