def plot_samples(samples, dist, noise, modelno, num_samples, timestamp):
    """Plot the observed samples and posterior samples side-by-side.

    The left panel scatters a freshly simulated dataset,
    ``simulate_dataset(dist, noise, num_samples)``, with axis limits taken
    from ``simulator_limits[dist]``. The right panel scatters ``samples``,
    one colour per distinct value in column 2, using the same limits. The
    figure is saved to
    ``filename_samples_figure(dist, noise, modelno, timestamp)`` and then
    closed.

    :param samples: array-like indexed as ``samples[:, 0]`` (x1),
        ``samples[:, 1]`` (x2) and ``samples[:, 2]`` (cluster label).
        Converted with ``np.asarray``. Presumably an (N, 3) array of
        posterior draws -- TODO confirm against caller.
    :param dist: simulator name; used in the title and as a key into
        ``simulator_limits``.
    :param noise: noise level; shown in the title with two decimals.
    :param modelno: model/sample index; shown in the title as an integer.
    :param num_samples: number of observed points to simulate.
    :param timestamp: forwarded to ``filename_samples_figure``.
    """
    # Single-argument print(...) prints the same under Python 2 and 3.
    print('Plotting samples %s %f' % (dist, noise))
    samples = np.asarray(samples)
    fig, ax = plt.subplots(nrows=1, ncols=2)
    try:
        fig.suptitle(
            '%s (noise %1.2f, sample %d)' % (dist, noise, modelno),
            size=16)

        # Left panel: freshly simulated observed data.
        T = simulate_dataset(dist, noise, num_samples)
        ax[0].text(
            .5, .95, 'Observed Data',
            horizontalalignment='center',
            transform=ax[0].transAxes)
        ax[0].set_xlabel('x1')
        ax[0].set_ylabel('x2')
        ax[0].scatter(T[:, 0], T[:, 1], color='k', alpha=.5)
        ax[0].set_xlim(simulator_limits[dist][0])
        ax[0].set_ylim(simulator_limits[dist][1])
        ax[0].grid()

        # Right panel: posterior samples, one colour per cluster label.
        ax[1].text(
            .5, .95, 'CrossCat Posterior Samples',
            horizontalalignment='center',
            transform=ax[1].transAxes)
        ax[1].set_xlabel('x1')
        # np.unique returns the labels sorted, so the label -> colour mapping
        # is reproducible (iterating a set gives no ordering guarantee).
        clusters = np.unique(samples[:, 2])
        # Two more colours than clusters are generated; zip consumes only
        # the first len(clusters) of them.
        palette = matplotlib.cm.gist_rainbow(
            np.linspace(0, 1, len(clusters) + 2))
        for label, color in zip(clusters, palette):
            members = samples[samples[:, 2] == label][:, [0, 1]]
            ax[1].scatter(members[:, 0], members[:, 1], alpha=.5, color=color)
        # Reuse the observed panel's limits so the two panels are comparable.
        ax[1].set_xlim(ax[0].get_xlim())
        ax[1].set_ylim(ax[0].get_ylim())
        ax[1].grid()

        fig.savefig(filename_samples_figure(dist, noise, modelno, timestamp))
    finally:
        # Close this figure even if plotting or saving raised, so it is not
        # leaked; figures owned by the caller are left open.
        plt.close(fig)
评论列表
文章目录