def plot_hist(baseline_samples, target_samples, true_x, true_y):
    """Overlay KDEs of baseline and target samples, mark ``true_y``, and save.

    Draws a shaded KDE of ``baseline_samples`` (reddish) and of
    ``target_samples`` (bluish) on the current axes. The x-axis is clipped to
    the range of the baseline samples, so target density outside that range
    is not shown. A thin red vertical line is drawn at ``true_y``. The figure
    is titled, resized to 9x9 inches and saved through ``utils.save_fig``
    under a name derived from ``utils.DATA_DIR``.

    Args:
        baseline_samples: Array-like exposing ``squeeze``/``min``/``max``;
            presumably 1-D after squeezing (e.g. shape (n, 1)) — TODO confirm.
        target_samples: Array-like treated the same way as baseline_samples.
        true_x: Input location shown in the title as ``at {true_x:.2f}``;
            may be None, in which case the title is just ``Predictive``.
        true_y: Reference value marked with a vertical line; may be None, in
            which case no line is drawn.
    """
    baseline_samples = baseline_samples.squeeze()
    target_samples = target_samples.squeeze()
    bmin, bmax = baseline_samples.min(), baseline_samples.max()

    # NOTE(review): `shade` is deprecated since seaborn 0.11 in favour of
    # `fill`; switch once the minimum supported seaborn version allows it.
    ax = sns.kdeplot(baseline_samples, shade=True, color=(0.6, 0.1, 0.1, 0.2))
    ax = sns.kdeplot(target_samples, shade=True, color=(0.1, 0.1, 0.6, 0.2))
    # The x-range follows the baseline only; the target KDE may be cut off.
    ax.set_xlim(bmin, bmax)

    # Mirror the optional handling of true_x: no marker when true_y is None.
    if true_y is not None:
        y0, y1 = ax.get_ylim()
        # End the marker 1% of the y-span below the top, presumably so it
        # does not run into the axes frame.
        top = y1 - (y1 - y0) * 0.01
        ax.plot([true_y, true_y], [0, top], linewidth=1, color='r')

    title = 'Predictive'
    if true_x is not None:
        title += f' at {true_x:.2f}'
    ax.set_title(title)

    fig = ax.figure
    fig.set_size_inches(9, 9)

    # Flatten the data directory path into a filename-safe suffix.
    name = utils.DATA_DIR.replace('/', '-')
    utils.save_fig('predictive-at-point-' + name)
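# Web-page residue from where this snippet was copied (not code):
# "评论列表" = "Comment list", "文章目录" = "Table of contents".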