plot.py source code


Project: tensorflow-mle  Author: kyleclo
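For reference, the surface stored in `loss_grid` is the Gaussian negative log-likelihood of the samples $x_1, \dots, x_n$ as a function of $(\mu, \sigma)$. Its minimum lies at the sample mean and the biased ($1/n$) sample standard deviation, which is where a converged optimisation path should end:

$$
\mathcal{L}(\mu, \sigma) = -\sum_{i=1}^{n} \log \mathcal{N}(x_i \mid \mu, \sigma^2)
= \frac{n}{2}\log(2\pi) + n \log \sigma + \frac{1}{2\sigma^2}\sum_{i=1}^{n} (x_i - \mu)^2,
$$

$$
\hat{\mu} = \frac{1}{n}\sum_{i=1}^{n} x_i, \qquad
\hat{\sigma} = \sqrt{\frac{1}{n}\sum_{i=1}^{n} (x_i - \hat{\mu})^2}.
$$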
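```python
import numpy as np
import scipy.stats as sp
import matplotlib.pyplot as plt


def plot_canonical_gauss(x, obs_mu, obs_sigma, obs_loss,
                         title, epsilon=0.05, breaks=100):
    """Contour the Gaussian negative log-likelihood of samples `x` over a
    (mu, sigma) grid, overlay the optimisation path (obs_mu, obs_sigma),
    and plot the loss per iteration in a second panel."""
    # compute a grid covering the observed path plus a margin of epsilon;
    # the lower sigma bound is clipped at 0.0
    mu_grid = np.linspace(start=min(obs_mu) - epsilon,
                          stop=max(obs_mu) + epsilon,
                          num=breaks)
    sigma_grid = np.linspace(start=max(min(obs_sigma) - epsilon, 0.0),
                             stop=max(obs_sigma) + epsilon,
                             num=breaks)
    mu_grid, sigma_grid = np.meshgrid(mu_grid, sigma_grid)
    # negative log-likelihood summed over the samples; scipy returns NaN
    # where scale == 0, so a clipped bottom row of the grid is NaN
    loss_grid = -np.sum(
        [sp.norm(loc=mu_grid, scale=sigma_grid).logpdf(x=xi) for xi in x],
        axis=0)

    # left: NLL contours with the optimisation path; right: loss per iteration
    fig, ax = plt.subplots(nrows=1, ncols=2)
    ax[0].contour(mu_grid, sigma_grid, loss_grid,
                  # nanmin/nanmax skip NaN cells (np.min would return NaN)
                  levels=np.linspace(np.nanmin(loss_grid),
                                     np.nanmax(loss_grid),
                                     breaks),
                  cmap='terrain')
    ax[0].plot(obs_mu, obs_sigma, color='red', alpha=0.5,
               linestyle='dashed', linewidth=1, marker='.', markersize=3)
    ax[0].set_xlabel('mu')
    ax[0].set_ylabel('sigma')
    ax[1].plot(range(len(obs_loss)), obs_loss)
    ax[1].set_xlabel('iter')
    ax[1].set_ylabel('loss')
    plt.suptitle(str(title))
    plt.show()
```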
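The function expects an optimisation trace (`obs_mu`, `obs_sigma`, `obs_loss`) from some fitting procedure; the project's own fitting code is not shown here. As a minimal NumPy stand-in, assuming plain gradient descent on the Gaussian negative log-likelihood with `sigma` parameterised as `exp(log_sigma)` to keep it positive, such a trace could be produced as below. The names `gauss_nll`, `gauss_mle_trace`, `lr` and `n_iter` are hypothetical, chosen for illustration only.

```python
import numpy as np


def gauss_nll(x, mu, sigma):
    """Summed negative log-likelihood of samples x under N(mu, sigma^2).

    mu and sigma may be scalars or same-shaped arrays (e.g. from
    np.meshgrid); the samples are broadcast along a trailing axis.
    """
    x = np.asarray(x, dtype=float)
    mu = np.asarray(mu, dtype=float)[..., None]
    sigma = np.asarray(sigma, dtype=float)[..., None]
    z = (x - mu) / sigma
    return np.sum(0.5 * np.log(2 * np.pi) + np.log(sigma) + 0.5 * z ** 2,
                  axis=-1)


def gauss_mle_trace(x, mu0=0.0, log_sigma0=0.0, lr=0.1, n_iter=500):
    """Hypothetical fitting loop: gradient descent on the per-sample mean NLL,
    recording (mu, sigma, summed NLL) before each update."""
    x = np.asarray(x, dtype=float)
    mu, log_sigma = mu0, log_sigma0
    obs_mu, obs_sigma, obs_loss = [], [], []
    for _ in range(n_iter):
        sigma = np.exp(log_sigma)
        obs_mu.append(mu)
        obs_sigma.append(sigma)
        obs_loss.append(float(gauss_nll(x, mu, sigma)))
        # d(mean NLL)/d(mu) = -mean(x - mu) / sigma^2
        grad_mu = -np.mean(x - mu) / sigma ** 2
        # d(mean NLL)/d(log_sigma) = 1 - mean((x - mu)^2) / sigma^2
        grad_log_sigma = 1.0 - np.mean((x - mu) ** 2) / sigma ** 2
        mu = mu - lr * grad_mu
        log_sigma = log_sigma - lr * grad_log_sigma
    return obs_mu, obs_sigma, obs_loss


rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=1.5, size=500)
obs_mu, obs_sigma, obs_loss = gauss_mle_trace(x)
print(obs_mu[-1], x.mean())     # should agree closely
print(obs_sigma[-1], x.std())   # np.std uses ddof=0, the MLE
```

The recorded loss is the summed NLL, matching the scale of `loss_grid`, while the gradient step uses the per-sample mean so the learning rate does not depend on the sample size. With `plot_canonical_gauss` in scope, a call such as `plot_canonical_gauss(x, obs_mu, obs_sigma, obs_loss, title='Gaussian MLE')` would then draw the path over the contours.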