# test_mvkde.py (source file)
#
# Language: python
# Views: 23 · Favorites: 0 · Likes: 0 · Comments: 0
#
# Project: cgpm · Author: probcomp · (project source · file source)
def test_bivariate_conditional_two_sample(noise):
    """Check joint and conditional simulation of a bivariate KDE.

    Training data is drawn from a ``Linear`` cgpm over outputs 0 and 1. The
    original author describes it as a bivariate normal with correlation
    ``1 - noise``. A two-output ``MultivariateKde`` is fitted to that data.

    The test then does two things:

    * It scatter-plots joint samples from the reference model and the KDE
      side by side. The plots are useful only when inspecting the test
      interactively; every figure is closed before the numerical check.
    * For each ``x0`` on a 100-point grid over [-3, 3], it simulates output 1
      conditioned on ``{0: x0}`` from both models. It then compares the two
      sets of per-``x0`` sample means with a two-sample Kolmogorov-Smirnov
      test and requires ``p > 0.01``.

    :param noise: noise level passed to ``Linear``. Presumably supplied by a
        pytest parametrization defined elsewhere in the file; verify there.
    """
    N_SAMPLES = 100

    rng = gu.gen_rng(2)
    # Synthetic training samples from the reference model.
    linear = Linear(outputs=[0, 1], noise=noise, rng=rng)
    samples_train = np.asarray(
        [[s[0], s[1]] for s in linear.simulate(-1, [0, 1], N=N_SAMPLES)])
    # Bivariate KDE. O/ST/SA/N are module-level names defined outside this
    # view. Presumably they are the outputs/stattypes/statargs keys and the
    # numerical stattype; confirm against the imports.
    kde = MultivariateKde(
        [0, 1], None, distargs={O: {ST: [N, N], SA: [{}, {}]}}, rng=rng)
    # Incorporate observations, one row per training sample.
    for rowid, row in enumerate(samples_train):
        kde.incorporate(rowid, {0: row[0], 1: row[1]})
    # Run inference.
    kde.transition()
    # Posterior samples from the KDE joint.
    samples_gen = np.asarray(
        [[s[0], s[1]] for s in kde.simulate(-1, [0, 1], N=N_SAMPLES)])

    fig, axes = plt.subplots(nrows=1, ncols=2)
    # try/finally guarantees the figures are released even if sampling or
    # plotting raises. Previously they leaked on the failure path.
    try:
        # Plot comparisons of the joint.
        joint_panels = zip(
            axes, ['b', 'r'], ['Train', 'KDE'], [samples_train, samples_gen])
        for axis, color, label, samples in joint_panels:
            axis.scatter(samples[:, 0], samples[:, 1], color=color, label=label)
            axis.grid()
            axis.legend(framealpha=0)
        # Conditional samples of output 1 given output 0 == x0. The arrays
        # have shape (len(xs), N_SAMPLES): one row per grid point.
        xs = np.linspace(-3, 3, 100)
        cond_samples_a = np.asarray(
            [[s[1] for s in linear.simulate(-1, [1], {0: x0}, N=N_SAMPLES)]
            for x0 in xs])
        cond_samples_b = np.asarray(
            [[s[1] for s in kde.simulate(-1, [1], {0: x0}, N=N_SAMPLES)]
            for x0 in xs])
        # Overlay the conditional mean curve on each joint plot.
        for axis, cond in zip(axes, [cond_samples_a, cond_samples_b]):
            axis.plot(xs, np.mean(cond, axis=1), linewidth=3, color='g')
            axis.set_xlim([-5, 4])
            axis.set_ylim([-5, 4])
    finally:
        plt.close('all')

    # Two-sample KS test on the per-x0 conditional means. A small p-value
    # would indicate the KDE's conditional mean curve differs from the
    # reference model's.
    mean_a = np.mean(cond_samples_a, axis=1)
    mean_b = np.mean(cond_samples_b, axis=1)
    _, p = ks_2samp(mean_a, mean_b)
    assert .01 < p
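# Comment list
# Table of contents
#
#
# Questions
#
#
# Interview experiences
#
#
# Articles
#
# WeChat
# Official account
#
# Scan the QR code to follow the official account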