gplvm_aep_examples.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:geepee 作者: thangbui 项目源码 文件源码
def run_cluster_MC():
    """Fit an ``aep.SGPLVM`` to synthetic three-group data and plot the latents.

    Three ARD RBF kernels with different random lengthscale profiles are
    used to draw 10 GP sample functions each over the same 100 random
    5-dimensional inputs.  The draws are stacked into one observation
    matrix, a sparse GPLVM with a Gaussian likelihood is trained with Adam
    using ``config.PROP_MC`` as the propagation mode, and the posterior
    latent means are scattered along the two latent dimensions with the
    smallest learned lengthscales, coloured by the generating kernel.

    Side effects: prints progress messages and the learned lengthscales,
    and writes the figure to ``/tmp/gplvm_cluster.pdf``.  Results are not
    reproducible across runs because no random seed is set here.
    """
    import GPy

    # create dataset
    print("creating dataset...")
    N = 100
    # Lengthscales are the reciprocal of a Dirichlet sample, so dimensions
    # with a small concentration (0.1) tend to get a very large lengthscale
    # and contribute little to the corresponding kernel.
    k1 = GPy.kern.RBF(
        5, variance=1,
        lengthscale=1. / np.random.dirichlet(np.r_[10, 10, 10, 0.1, 0.1]),
        ARD=True)
    k2 = GPy.kern.RBF(
        5, variance=1,
        lengthscale=1. / np.random.dirichlet(np.r_[10, 0.1, 10, 0.1, 10]),
        ARD=True)
    k3 = GPy.kern.RBF(
        5, variance=1,
        lengthscale=1. / np.random.dirichlet(np.r_[0.1, 0.1, 10, 10, 10]),
        ARD=True)
    X = np.random.normal(0, 1, (N, 5))
    # Each draw is (10, N); transposing gives N rows of 10 outputs per group.
    A = np.random.multivariate_normal(np.zeros(N), k1.K(X), 10).T
    B = np.random.multivariate_normal(np.zeros(N), k2.K(X), 10).T
    C = np.random.multivariate_normal(np.zeros(N), k3.K(X), 10).T

    Y = np.vstack((A, B, C))
    # One label per row of Y: 0 for A, 1 for B, 2 for C.
    labels = np.hstack((
        np.zeros(A.shape[0]),
        np.ones(B.shape[0]),
        np.ones(C.shape[0]) * 2,
    ))

    # inference
    print("inference ...")
    M = 30
    D = 5
    alpha = 0.5
    lvm = aep.SGPLVM(Y, D, M, lik='Gaussian')
    lvm.optimise(method='adam', adam_lr=0.05, maxiter=2000,
                 alpha=alpha, prop_mode=config.PROP_MC)

    # NOTE(review): presumably `sgp_layer.ls` stores log-lengthscales, hence
    # the exponentiation -- confirm against the aep implementation.
    ls = np.exp(lvm.sgp_layer.ls)
    print(ls)
    # Ascending sort: the first two indices are the smallest lengthscales.
    inds = np.argsort(ls)
    plt.figure()
    # Only the posterior means are plotted; the second return value is unused.
    mx, _ = lvm.get_posterior_x()
    plt.scatter(mx[:, inds[0]], mx[:, inds[1]], c=labels)
    # The figure is written to disk rather than shown interactively.
    plt.savefig('/tmp/gplvm_cluster.pdf')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号