def plot_natural_gauss(x, obs_eta1, obs_eta2, obs_loss,
                       title, epsilon=0.05, breaks=300):
    """Show the loss surface over (eta1, eta2) next to the loss curve.

    The left panel draws contours of the summed negative normal
    log-density of ``x`` on an (eta1, eta2) grid and overlays the observed
    parameter path. The right panel plots ``obs_loss`` against its index.

    Args:
        x: Iterable of observations; each element is passed to ``logpdf``.
        obs_eta1: Sequence of eta1 values visited (x-coordinates of the
            path). Presumably the first natural parameter of a Gaussian,
            judging by the function name -- confirm against the caller.
        obs_eta2: Sequence of eta2 values visited (y-coordinates of the
            path).
        obs_loss: Sequence of loss values, one per iteration.
        title: Object formatted into the figure's super-title.
        epsilon: Margin added around the observed eta ranges when building
            the grid.
        breaks: Number of grid points per axis; also the number of contour
            levels.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # compute grid
    eta1_axis = np.linspace(start=min(obs_eta1) - epsilon,
                            stop=max(obs_eta1) + epsilon,
                            num=breaks)
    # The upper end of the eta2 axis is capped at 0.0, so the grid never
    # extends into positive eta2.
    eta2_axis = np.linspace(start=min(obs_eta2) - epsilon,
                            stop=min(max(obs_eta2) + epsilon, 0.0),
                            num=breaks)
    eta1_grid, eta2_grid = np.meshgrid(eta1_axis, eta2_axis)
    mu_grid = get_mu(eta1_grid, eta2_grid)
    sigma_grid = get_sigma(eta2_grid)

    # The frozen distribution does not depend on the observation, so build
    # it once rather than once per element of x.
    grid_dist = sp.norm(loc=mu_grid, scale=sigma_grid)
    loss_grid = -np.sum([grid_dist.logpdf(x=xi) for xi in x], axis=0)

    # Derive the contour levels from finite entries only. When the eta2
    # axis is clamped to exactly 0.0 the boundary row may be NaN/inf
    # (depends on get_sigma, which is defined elsewhere -- TODO confirm);
    # a single such entry would otherwise turn every level into NaN.
    finite_loss = loss_grid[np.isfinite(loss_grid)]
    levels = np.linspace(finite_loss.min(), finite_loss.max(), breaks)

    # plot contours and loss
    fig, ax = plt.subplots(nrows=1, ncols=2)
    ax[0].contour(eta1_grid, eta2_grid, loss_grid,
                  levels=levels,
                  cmap='terrain')
    ax[0].plot(obs_eta1, obs_eta2, color='red', alpha=0.5,
               linestyle='dashed', linewidth=1, marker='.', markersize=3)
    ax[0].set_xlabel('eta1')
    ax[0].set_ylabel('eta2')

    # plot() uses the sample index as x when only y is given.
    ax[1].plot(obs_loss)
    ax[1].set_xlabel('iter')
    # NOTE(review): the right panel has no y-label; a 'loss' label was
    # disabled in the original -- confirm whether that is intentional.

    fig.suptitle('{}'.format(title))
    plt.show()
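# [web-page residue, not part of the code] Comment list
# [web-page residue, not part of the code] Table of contents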