plot.py source code

python

Project: tensorflow-mle  Author: kyleclo
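The plotting function in this file works in the natural parameters (eta1, eta2) of a Gaussian. The page does not show the get_mu and get_sigma helpers it relies on, so the following assumes the standard exponential-family parameterization of N(μ, σ²). Under that assumption the two parameterizations are related by the first line below, and the surface that loss_grid evaluates is the summed negative log-likelihood on the second line:

$$
\eta_1 = \frac{\mu}{\sigma^2}, \qquad \eta_2 = -\frac{1}{2\sigma^2}
\quad\Longleftrightarrow\quad
\mu = -\frac{\eta_1}{2\eta_2}, \qquad \sigma = \sqrt{-\frac{1}{2\eta_2}}
$$

$$
L(\eta_1, \eta_2) = -\sum_{i=1}^{n} \log \mathcal{N}\!\left(x_i \mid \mu, \sigma^2\right)
= \frac{n}{2}\log\!\left(2\pi\sigma^2\right) + \sum_{i=1}^{n} \frac{(x_i - \mu)^2}{2\sigma^2}
$$

Since σ² > 0 requires η2 < 0, this is presumably why the upper limit of the eta2 grid is capped so that it never exceeds 0.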
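```python
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as sp

# get_mu(eta1, eta2) and get_sigma(eta2) are helpers defined elsewhere in the
# project (not shown on this page); they map the natural parameters to the
# Gaussian's mean and standard deviation.


def plot_natural_gauss(x, obs_eta1, obs_eta2, obs_loss,
                       title, epsilon=0.05, breaks=300):
    """Plot the Gaussian negative log-likelihood of data `x` over the natural
    parameters (eta1, eta2), overlay the optimizer path (obs_eta1, obs_eta2),
    and plot the loss per iteration."""
    # compute grid around the observed path; eta2 must stay negative
    # (sigma^2 = -1 / (2 * eta2)), so cap it just below 0 instead of at 0,
    # where sigma would be infinite and the contour levels would become nan
    eta1_grid = np.linspace(start=min(obs_eta1) - epsilon,
                            stop=max(obs_eta1) + epsilon,
                            num=breaks)
    eta2_grid = np.linspace(start=min(obs_eta2) - epsilon,
                            stop=min(max(obs_eta2) + epsilon, -1e-6),
                            num=breaks)

    eta1_grid, eta2_grid = np.meshgrid(eta1_grid, eta2_grid)

    mu_grid = get_mu(eta1_grid, eta2_grid)
    sigma_grid = get_sigma(eta2_grid)

    # negative log-likelihood at every grid point, summed over observations
    dist = sp.norm(loc=mu_grid, scale=sigma_grid)
    loss_grid = -np.sum([dist.logpdf(xi) for xi in x], axis=0)

    # left: loss contours with the optimization path; right: loss per iteration
    fig, ax = plt.subplots(nrows=1, ncols=2)
    ax[0].contour(eta1_grid, eta2_grid, loss_grid,
                  levels=np.linspace(np.min(loss_grid),
                                     np.max(loss_grid),
                                     breaks),
                  cmap='terrain')
    ax[0].plot(obs_eta1, obs_eta2, color='red', alpha=0.5,
               linestyle='dashed', linewidth=1, marker='.', markersize=3)
    ax[0].set_xlabel('eta1')
    ax[0].set_ylabel('eta2')
    ax[1].plot(range(len(obs_loss)), obs_loss)
    ax[1].set_xlabel('iter')
    # ax[1].set_ylabel('loss')
    plt.suptitle('{}'.format(title))
    plt.show()
```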
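The function above calls get_mu and get_sigma without showing them, and expects an optimization path as input. As a hedged sketch, assuming the standard parameterization eta1 = μ/σ² and eta2 = -1/(2σ²), the hypothetical helpers below reconstruct the two functions. gauss_nll (also hypothetical) computes the same summed negative log-likelihood as loss_grid from the sufficient statistics n, Σx and Σx². fit_trajectory (hypothetical) produces obs_eta1, obs_eta2 and obs_loss lists of the kind plot_natural_gauss takes; it uses plain gradient descent with a backtracking step, not the project's TensorFlow optimizer.

```python
import math

# Hypothetical stand-ins for the project's get_mu / get_sigma (not shown on the
# page), assuming eta1 = mu / sigma**2 and eta2 = -1 / (2 * sigma**2).
# Only arithmetic operators are used, so they also work elementwise on NumPy arrays.
def get_mu(eta1, eta2):
    return -eta1 / (2.0 * eta2)


def get_sigma(eta2):
    return (-1.0 / (2.0 * eta2)) ** 0.5


def gauss_nll(xs, eta1, eta2):
    """Summed negative log-likelihood of xs under N(mu, sigma^2).

    Uses only n, sum(x) and sum(x^2):
    n/2 * log(2*pi*sigma^2) + (sum(x^2) - 2*mu*sum(x) + n*mu^2) / (2*sigma^2).
    """
    n = len(xs)
    s1 = sum(xs)
    s2 = sum(x * x for x in xs)
    mu, sigma = get_mu(eta1, eta2), get_sigma(eta2)
    var = sigma * sigma
    return (0.5 * n * math.log(2.0 * math.pi * var)
            + (s2 - 2.0 * mu * s1 + n * mu * mu) / (2.0 * var))


def fit_trajectory(xs, eta1=0.0, eta2=-0.5, lr=0.1, steps=200):
    """Gradient descent in (eta1, eta2), recording the path and loss per step."""
    n = len(xs)
    m1 = sum(xs) / n               # sample mean of x
    m2 = sum(x * x for x in xs) / n  # sample mean of x^2
    obs_eta1, obs_eta2 = [eta1], [eta2]
    obs_loss = [gauss_nll(xs, eta1, eta2)]
    for _ in range(steps):
        mu, sigma = get_mu(eta1, eta2), get_sigma(eta2)
        g1 = mu - m1                       # d(mean NLL) / d eta1
        g2 = mu * mu + sigma * sigma - m2  # d(mean NLL) / d eta2
        step = lr
        # Backtrack: halve the step until eta2 stays negative and the loss drops;
        # if no such step is found, the parameters stay where they are.
        for _ in range(50):
            new1, new2 = eta1 - step * g1, eta2 - step * g2
            if new2 < 0.0 and gauss_nll(xs, new1, new2) < obs_loss[-1]:
                eta1, eta2 = new1, new2
                break
            step *= 0.5
        obs_eta1.append(eta1)
        obs_eta2.append(eta2)
        obs_loss.append(gauss_nll(xs, eta1, eta2))
    return obs_eta1, obs_eta2, obs_loss
```

With these definitions, a call such as `plot_natural_gauss(xs, *fit_trajectory(xs), title='Gaussian MLE')` would draw the path over the loss surface. To evaluate gauss_nll on a whole eta1_grid/eta2_grid at once, replace `math.log` with `np.log`; the backtracking check on eta2 is there because any step that reaches eta2 >= 0 leaves the valid region.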