gplvm_vfe_examples.py (Python source)

Project: geepee    Author: thangbui

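import numpy as np
import matplotlib.pyplot as plt
# Assumed import path for geepee's variational free energy (VFE) models;
# adjust it to match your checkout of the project.
import geepee.vfe_models as vfe


def run_semicircle():
    # Create a noisy 2-D dataset on a half circle: each point is
    # (sin(theta), cos(theta)) with theta = arccos(cos_val) in [0, pi].
    print("creating dataset...")
    N = 20
    cos_val = [0.97, 0.95, 0.94, 0.89, 0.8,
               0.88, 0.92, 0.96, 0.7, 0.65,
               0.3, 0.25, 0.1, -0.25, -0.3,
               -0.6, -0.67, -0.75, -0.97, -0.98]
    cos_val = np.array(cos_val).reshape((N, 1))
    # cos_val = 2*np.random.rand(N, 1) - 1
    angles = np.arccos(cos_val)
    sin_val = np.sin(angles)
    Y = np.hstack((sin_val, cos_val))
    Y += 0.05 * np.random.randn(Y.shape[0], Y.shape[1])

    # Inference: sparse GP-LVM with M inducing points and a D-dimensional
    # latent space, fitted by optimising its variational objective.
    print("inference ...")
    M = 10
    D = 2
    lvm = vfe.SGPLVM(Y, D, M, lik='Gaussian')
    lvm.optimise(method='L-BFGS-B', maxiter=2000)
    # lvm.optimise(method='adam', maxiter=2000)

    # Plot the observed data as blue squares.
    plt.figure()
    plt.plot(Y[:, 0], Y[:, 1], 'sb')

    # For each data point, take the posterior mean of its latent variable and
    # the points one standard deviation either side, map all three through
    # the GP, and draw the predicted outputs with +/- one predictive std.
    mx, vx = lvm.get_posterior_x()
    for i in range(mx.shape[0]):
        mxi = mx[i, :]
        vxi = vx[i, :]
        mxi1 = mxi + np.sqrt(vxi)
        mxi2 = mxi - np.sqrt(vxi)
        mxis = np.vstack([mxi.reshape((1, D)),
                          mxi1.reshape((1, D)),
                          mxi2.reshape((1, D))])
        myis, vyis = lvm.predict_f(mxis)

        plt.errorbar(myis[:, 0], myis[:, 1],
                     xerr=np.sqrt(vyis[:, 0]), yerr=np.sqrt(vyis[:, 1]), fmt='.k')

    plt.show()


if __name__ == '__main__':
    run_semicircle()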
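The dataset step in `run_semicircle()` depends on one fact: `arccos` maps [-1, 1] onto [0, π], where sine is non-negative. Every noise-free point therefore lies on the half of the unit circle with sin θ ≥ 0. In the plot this is the right half, because column 0 holds the sine.

The following is a minimal NumPy sketch of just that step, so you can check the property without geepee installed. The helper name `make_semicircle` and its `noise_std`/`rng` parameters are hypothetical and do not come from the original file.

```python
import numpy as np


def make_semicircle(cos_val, noise_std=0.05, rng=None):
    """Hypothetical helper mirroring the dataset step of run_semicircle().

    angles = arccos(cos_val) lies in [0, pi], so sin(angles) >= 0 and the
    noise-free points sit on the sin >= 0 half of the unit circle.
    """
    rng = np.random.default_rng() if rng is None else rng
    cos_val = np.asarray(cos_val, dtype=float).reshape(-1, 1)
    angles = np.arccos(cos_val)          # values in [0, pi]
    sin_val = np.sin(angles)             # non-negative on [0, pi]
    Y = np.hstack((sin_val, cos_val))    # column 0 = sin, column 1 = cos
    # Additive isotropic Gaussian noise, as in the original (0.05 std there).
    return Y + noise_std * rng.standard_normal(Y.shape)


# Noise-free check: norms should be 1 (up to rounding), first column >= 0.
Y_clean = make_semicircle([0.97, 0.3, -0.6, -0.98], noise_std=0.0)
print(np.linalg.norm(Y_clean, axis=1))
print(Y_clean[:, 0] >= 0)
```

Passing an explicit `rng` such as `np.random.default_rng(0)` makes the noisy dataset reproducible. The original relies on the global `np.random` state instead.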
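The plotting loop in `run_semicircle()` pushes each latent posterior mean, and the points one posterior standard deviation either side of it, through the model's predictive distribution. The sketch below imitates that loop without geepee.

Its predictor, `gp_predict_f`, is a hypothetical stand-in: an exact GP with an RBF kernel conditioned on fixed inputs. It is not the sparse variational predictor behind `lvm.predict_f`. The helpers `rbf` and `latent_error_points`, the toy inputs, and the latent variances are likewise made up for illustration.

```python
import numpy as np


def rbf(A, B, lengthscale=1.0, variance=1.0):
    """Squared-exponential (RBF) kernel matrix between the rows of A and B."""
    sq = np.sum(A**2, 1)[:, None] + np.sum(B**2, 1)[None, :] - 2.0 * A @ B.T
    return variance * np.exp(-0.5 * np.maximum(sq, 0.0) / lengthscale**2)


def gp_predict_f(X, Y, Xnew, noise=1e-2, lengthscale=1.0, variance=1.0):
    """Hypothetical stand-in for lvm.predict_f: exact (non-sparse) GP posterior
    mean and marginal variance of f at Xnew, conditioned on fixed inputs X.
    Returns (mean, var), both of shape (len(Xnew), Y.shape[1]).
    """
    K = rbf(X, X, lengthscale, variance) + noise * np.eye(len(X))
    Ks = rbf(X, Xnew, lengthscale, variance)             # (N, Nnew)
    L = np.linalg.cholesky(K)                            # K = L L^T
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, Y))  # K^{-1} Y
    mean = Ks.T @ alpha
    v = np.linalg.solve(L, Ks)
    # Prior variance minus the part explained by the data, clamped at 0.
    var = np.maximum(variance - np.sum(v**2, axis=0), 0.0)
    # All output columns share one kernel, so they share one variance.
    return mean, np.tile(var[:, None], (1, Y.shape[1]))


def latent_error_points(predict, mx, vx):
    """Mirror the plotting loop: predict f at m_i and at m_i +/- sqrt(v_i)."""
    results = []
    for mxi, vxi in zip(mx, vx):
        s = np.sqrt(vxi)                           # per-dimension latent std
        mxis = np.vstack([mxi, mxi + s, mxi - s])  # shape (3, D)
        results.append(predict(mxis))              # (mean, var), each (3, out_dim)
    return results


# Hypothetical toy setup: four well-separated 2-D inputs and 2-D outputs.
X = np.array([[0.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 2.0]])
Y = np.hstack([np.sin(X[:, :1]), np.cos(X[:, 1:])])
mx, vx = X.copy(), np.full(X.shape, 0.01)  # pretend latent posterior
results = latent_error_points(lambda Z: gp_predict_f(X, Y, Z), mx, vx)
for mean, var in results:
    # Same error-bar convention as the original: +/- one predictive std.
    print(mean[0], np.sqrt(var[0]))
```

Each data point yields three black markers. Their spread reflects uncertainty in the latent position. The error bars on each marker reflect predictive uncertainty in the output space.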