def test_joint(state):
    """Compare joint simulations of (x, z) against DATA, one indicator at a time.

    Draws N_SAMPLES joint samples with ``state.simulate(-1, [0,1], ...)``.
    For every value in INDICATORS, the rows of DATA whose column 1 equals
    that value are scatter-plotted next to the simulated x values carrying
    the same indicator, and the two sets of x values are compared with
    ``ks_2samp``.

    Parameters
    ----------
    state
        Object exposing ``simulate(rowid, query, N=...)``; each returned
        sample is indexed as ``sample[0]`` (x) and ``sample[1]`` (indicator).

    Raises
    ------
    AssertionError
        If, for any indicator, the second element of the ``ks_2samp`` result
        (used here as the p-value) is not greater than 0.05.
    """
    # Simulate from the joint distribution of (x,z).
    draws = state.simulate(-1, [0,1], N=N_SAMPLES)

    _, axes = plt.subplots()
    axes.set_title('Joint Simulation')

    for indicator in INDICATORS:
        # Observed rows belonging to this indicator.
        observed = DATA[DATA[:,1] == indicator]
        observed_x = observed[:,0]
        axes.scatter(observed[:,1], observed_x, color=gu.colors[indicator])

        # Simulated x values whose sampled indicator matches.
        simulated_x = [draw[0] for draw in draws if draw[1] == indicator]
        # Shift the simulated points right by .25 so they are drawn beside,
        # not on top of, the observed points for the same indicator.
        shifted_positions = np.add([indicator] * len(simulated_x), .25)
        axes.scatter(
            shifted_positions, simulated_x, color=gu.colors[indicator])

        # Two-sample KS test between observed and simulated x values.
        ks_result = ks_2samp(observed_x, simulated_x)
        pvalue = ks_result[1]
        assert pvalue > .05

    axes.set_xlabel('Indicator')
    axes.set_ylabel('x')
    axes.grid()
# Comment list
# Article table of contents