test_mvkde.py source code

python

Project: cgpm  Author: probcomp
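import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import ks_2samp

from cgpm.kde.mvkde import MultivariateKde
from cgpm.utils import general as gu

# O, ST, SA, N and SAMPLES are module-level constants defined elsewhere in
# test_mvkde.py and are not shown in this excerpt; `i` indexes SAMPLES.
def test_univariate_two_sample(i):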
    # Checks that posterior samples from a univariate KDE fitted to a unimodal
    # or bimodal distribution on R match held-out samples from the same
    # distribution. When the plot is shown, a density curve overlays the
    # samples, which makes it easy to see that logpdf and simulate agree.
    N_SAMPLES = 100

    rng = gu.gen_rng(2)
    # Synthetic samples.
    samples_train = SAMPLES[i](N_SAMPLES, rng)
    samples_test = SAMPLES[i](N_SAMPLES, rng)
    # Univariate KDE on output 3 (one stattype N, empty statargs).
    kde = MultivariateKde([3], None, distargs={O: {ST: [N], SA: [{}]}}, rng=rng)
    # Incorporate observations.
    for rowid, x in enumerate(samples_train):
        kde.incorporate(rowid, {3: x})
    # Run inference.
    kde.transition()
    # Generate posterior samples.
    samples_gen = [s[3] for s in kde.simulate(-1, [3], N=N_SAMPLES)]
    # Plot comparison of all train, test, and generated samples.
    fig, ax = plt.subplots()
    ax.scatter(samples_train, [0]*len(samples_train), color='b', label='Train')
    ax.scatter(samples_gen, [1]*len(samples_gen), color='r', label='KDE')
    ax.scatter(samples_test, [2]*len(samples_test), color='g', label='Test')
    # Overlay the density function.
    xs = np.linspace(ax.get_xlim()[0], ax.get_xlim()[1], 200)
    pdfs = [kde.logpdf(-1, {3: x}) for x in xs]
    # Rescale the densities into the range [1, 1.5] so the curve sits between
    # the KDE row (y=1) and the Test row (y=2).
    densities = np.exp(pdfs)
    pdfs_plot = 1 + 0.5 * densities / densities.max()
    ax.plot(xs, pdfs_plot, color='k')
    # Title and axis labels; hide the y tick labels, which carry no meaning.
    ax.set_title('Univariate KDE Posterior versus Generator')
    ax.set_xlabel('x')
    ax.set_yticklabels([])
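    # Add a grid, then close the figure so the test does not leak open
    # figures (use plt.show() instead to inspect the plot interactively).
    ax.grid()
    plt.close(fig)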
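    # Two-sample Kolmogorov-Smirnov test: fail if the generated and held-out
    # test samples differ significantly (p <= 0.05).
    _, p = ks_2samp(samples_test, samples_gen)
    assert p > 0.05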
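Running the test requires cgpm plus the module-level names (`SAMPLES`, `O`, `ST`, `SA`, `N`) that this excerpt does not show. As a hedged, self-contained sketch of the same train / simulate / KS-test idea, the block below swaps in `scipy.stats.gaussian_kde` for cgpm's `MultivariateKde` (a different KDE implementation, not cgpm's API). The sampler `bimodal_samples` and the helper `kde_two_sample_pvalue` are hypothetical stand-ins for an entry of `SAMPLES` and for the test body.

```python
import numpy as np
from scipy.stats import gaussian_kde, ks_2samp


def bimodal_samples(n, rng):
    """Hypothetical sampler: equal mixture of N(-2, 0.5^2) and N(2, 0.5^2)."""
    centers = rng.choice([-2.0, 2.0], size=n)
    return centers + rng.normal(0.0, 0.5, size=n)


def kde_two_sample_pvalue(sampler, n_samples=100, seed=2):
    """Fit a KDE to one draw, simulate from it, and KS-test against a second draw.

    Returns the p-value of the two-sample Kolmogorov-Smirnov test.
    """
    rng = np.random.RandomState(seed)
    # Independent training and held-out test samples from the same sampler.
    samples_train = sampler(n_samples, rng)
    samples_test = sampler(n_samples, rng)
    # Gaussian KDE (stand-in for MultivariateKde); scipy picks the bandwidth
    # with Scott's rule by default.
    kde = gaussian_kde(samples_train)
    # resample() returns an array of shape (dims, n_samples); dims == 1 here.
    samples_gen = kde.resample(n_samples, seed=rng)[0]
    # A small p-value is evidence that the generated and held-out samples
    # come from different distributions.
    _, p = ks_2samp(samples_test, samples_gen)
    return p


if __name__ == '__main__':
    p = kde_two_sample_pvalue(bimodal_samples)
    print('KS p-value: %.3f' % p)
    print('pass' if p > 0.05 else 'fail')
```

A fixed seed keeps the outcome reproducible, as the original test does with `gu.gen_rng(2)`; with a 0.05 threshold, a correctly fitted KDE can still fail for some seeds.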