gplvm_vfe_examples.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:geepee 作者: thangbui 项目源码 文件源码
def run_pinwheel():
    """Fit ``vfe.SGPLVM`` to a synthetic pinwheel dataset and plot the result.

    Builds a 2-D, 3-arm pinwheel dataset (50 points per arm) and fits
    ``vfe.SGPLVM`` to it with ``D = 2`` and ``M = 20``, optimising with
    L-BFGS-B. It then shows the observed data (left) next to the posterior
    latent means with error bars (right).
    """
    def make_pinwheel(radial_std, tangential_std, num_classes, num_per_class,
                      rate, rs=None):
        """Sample 2-D points arranged as a pinwheel with ``num_classes`` arms.

        Based on code by Ryan P. Adams.

        Args:
            radial_std: noise scale applied to the first (radial) coordinate.
            tangential_std: noise scale applied to the second coordinate.
            num_classes: number of arms.
            num_per_class: number of points per arm.
            rate: how strongly a point is rotated as a function of its
                radial coordinate (the rotation grows as ``rate * exp(r)``).
            rs: ``np.random.RandomState`` to draw from. When omitted, a fresh
                one seeded with 0 is created, so every call is reproducible.

        Returns:
            Array of shape ``(num_classes * num_per_class, 2)``.
        """
        if rs is None:
            # Built per call rather than as a default argument, so the RNG
            # state is not shared and silently advanced between calls.
            rs = np.random.RandomState(0)
        rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False)

        features = rs.randn(num_classes * num_per_class, 2) \
            * np.array([radial_std, tangential_std])
        features[:, 0] += 1  # move points off the origin along the first axis
        labels = np.repeat(np.arange(num_classes), num_per_class)

        # Each arm starts at its own base angle; points with a larger first
        # coordinate are rotated further, which produces the spiral arms.
        angles = rads[labels] + rate * np.exp(features[:, 0])
        rotations = np.stack([np.cos(angles), -np.sin(angles),
                              np.sin(angles), np.cos(angles)])
        # (4, N) -> (N, 2, 2): one [[cos, -sin], [sin, cos]] matrix per point.
        rotations = np.reshape(rotations.T, (-1, 2, 2))

        # Per-point row-vector/matrix product: out[t] = features[t] @ R[t].
        return np.einsum('ti,tij->tj', features, rotations)

    # create dataset
    # print(...) with a single string works identically on Python 2 and 3.
    print("creating dataset...")
    Y = make_pinwheel(radial_std=0.3, tangential_std=0.05, num_classes=3,
                      num_per_class=50, rate=0.4)

    # inference
    print("inference ...")
    M = 20  # NOTE(review): presumably the number of inducing points — confirm against vfe.SGPLVM
    D = 2   # NOTE(review): presumably the latent dimensionality; the plot below reads two latent columns
    lvm = vfe.SGPLVM(Y, D, M, lik='Gaussian')
    lvm.optimise(method='L-BFGS-B')

    mx, vx = lvm.get_posterior_x()

    fig = plt.figure()
    # Left panel: observed data.
    ax = fig.add_subplot(121)
    ax.plot(Y[:, 0], Y[:, 1], 'bx')
    # Right panel: posterior latent means. NOTE(review): vx is assumed to be
    # a per-dimension variance, so sqrt(vx) gives one-standard-deviation bars.
    ax = fig.add_subplot(122)
    ax.errorbar(mx[:, 0], mx[:, 1], xerr=np.sqrt(
        vx[:, 0]), yerr=np.sqrt(vx[:, 1]), fmt='xk')
    plt.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号