test_normal_categorical.py source code

python

Project: cgpm  Author: probcomp
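import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import ks_2samp

from cgpm.utils import general as gu  # assumed source of gu.colors

# DATA, INDICATORS, N_SAMPLES and the `state` fixture are defined elsewhere
# in the test module and are not shown in this excerpt.
def test_joint(state):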
    # Simulate from the joint distribution of (x,z).
    joint_samples = state.simulate(-1, [0,1], N=N_SAMPLES)
    _, ax = plt.subplots()
    ax.set_title('Joint Simulation')
    for t in INDICATORS:
        # Plot original data.
        data_subpop = DATA[DATA[:,1] == t]
        ax.scatter(data_subpop[:,1], data_subpop[:,0], color=gu.colors[t])
        # Plot simulated data for indicator t.
        samples_subpop = [j[0] for j in joint_samples if j[1] == t]
        ax.scatter(
            np.add([t]*len(samples_subpop), .25), samples_subpop,
            color=gu.colors[t])
        # KS test.
        pvalue = ks_2samp(data_subpop[:,0], samples_subpop)[1]
        assert .05 < pvalue
    ax.set_xlabel('Indicator')
    ax.set_ylabel('x')
    ax.grid()
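
The check at the heart of test_joint is a two-sample Kolmogorov-Smirnov test run separately for each indicator value. The sketch below isolates that step without cgpm. The helpers `ks_pvalues_by_indicator` and `make_mixture`, the synthetic means, and the array layout (column 0 is x, column 1 is the indicator) are assumptions made for illustration and are not part of the cgpm test module.

```python
import numpy as np
from scipy.stats import ks_2samp


def ks_pvalues_by_indicator(observed, simulated, indicators):
    """Return {indicator: p-value} from a two-sample KS test on column 0.

    Both inputs are (n, 2) arrays: column 0 holds x, column 1 the indicator.
    (Hypothetical helper, not part of cgpm.)
    """
    pvalues = {}
    for t in indicators:
        # Select the x values belonging to subpopulation t in each array.
        obs_x = observed[observed[:, 1] == t, 0]
        sim_x = simulated[simulated[:, 1] == t, 0]
        pvalues[t] = ks_2samp(obs_x, sim_x).pvalue
    return pvalues


def make_mixture(rng, means, n_per_group):
    """Draw unit-variance normal x values, one group per indicator value."""
    rows = [
        np.column_stack(
            [rng.normal(mu, 1.0, n_per_group), np.full(n_per_group, t)])
        for t, mu in enumerate(means)
    ]
    return np.vstack(rows)


rng = np.random.default_rng(0)
observed = make_mixture(rng, means=[0.0, 5.0, 10.0], n_per_group=200)
# Every group mean is shifted by five standard deviations.
shifted = make_mixture(rng, means=[5.0, 10.0, 15.0], n_per_group=200)

# Comparing a sample with itself gives KS statistic 0, so no rejection.
same = ks_pvalues_by_indicator(observed, observed, [0, 1, 2])
# Comparing against the shifted sample should reject for every indicator.
diff = ks_pvalues_by_indicator(observed, shifted, [0, 1, 2])
```

An assertion such as `.05 < pvalue` passes only when the test fails to reject equality of the two distributions. Because a correct model will still produce a p-value below .05 in roughly 5% of runs, a test like this can fail occasionally unless the random seed is fixed.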