fig2.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:gumbel-relatives 作者: matejbalog 项目源码 文件源码
def main(args_dict):
    """Estimate MSEs of the estimators across alpha and save the figure.

    Args:
        args_dict: Mapping of parsed command line options. Keys read here:
            'Ms' -- sequence converted to an array; one plotted series per
                entry (labelled ``$M=...$``).
            'alpha_min', 'alpha_max', 'alpha_num' -- passed to
                ``np.linspace`` to build the grid of alpha values.
            'K' -- forwarded to ``estimate_MSE_vs_alpha`` and embedded in
                the output file name.
            'confidence' -- forwarded to ``plot_MSEs_to_axis``.

    Side effects:
        Prints two progress messages, builds a matplotlib figure and hands
        it to ``save_plot`` with the path stem ``figures/fig2_K<K>``.
    """
    # Extract configuration from command line arguments
    Ms = np.array(args_dict['Ms'])
    alphas = np.linspace(args_dict['alpha_min'], args_dict['alpha_max'], args_dict['alpha_num'])
    K = args_dict['K']
    do_confidence = args_dict['confidence']

    # Estimate MSEs by sampling
    print('Estimating MSE of estimators of Z...')
    MSEs_Z, MSE_stdevs_Z = estimate_MSE_vs_alpha(lambda x: x, Ms, alphas, K)
    print('Estimating MSE of estimators of ln(Z)...')
    MSEs_lnZ, MSE_stdevs_lnZ = estimate_MSE_vs_alpha(np.log, Ms, alphas, K)

    # Set up plot
    matplotlib_configure_as_notebook()
    fig = plt.figure(facecolor='w', figsize=(8.25, 3.25))
    gs = gridspec.GridSpec(1, 3, width_ratios=[1.0, 1.0, 0.5])
    # ax[0] and ax[2] are the two data panels (grid cells 0 and 1); ax[1] is
    # the narrow third cell, used further down only to host the legend.
    ax = [plt.subplot(gs[0]), plt.subplot(gs[2]), plt.subplot(gs[1])]

    ax[0].set_xlabel('$\\alpha$')
    ax[2].set_xlabel('$\\alpha$')
    ax[0].set_ylabel('MSE of estimators of $Z$, in units of $Z^2$')
    # Backslash escaped explicitly: '\l' is not a valid escape sequence and
    # newer Python versions warn about it. The runtime string is unchanged.
    ax[2].set_ylabel('MSE of estimators of $\\ln Z$, in units of $1$')

    # One colour per entry of Ms. `range` (not `xrange`) so this runs on
    # Python 3 as well as Python 2.
    num_Ms = len(Ms)
    colors = [plt.cm.plasma(0.8 - 1.0 * i / num_Ms) for i in range(num_Ms)]

    # Gumbel (alpha=0) and Exponential (alpha=1) tricks can be handled analytically
    legend_Gumbel = 'Gumbel trick\n($\\alpha=0$, theoretical)'
    legend_Exponential = 'Exponential trick\n($\\alpha=1$, theoretical)'
    ax[0].scatter(np.zeros(num_Ms), Z_Gumbel_MSE(Ms), marker='o', color=colors, label=legend_Gumbel)
    ax[0].scatter(np.ones(num_Ms), Z_Exponential_MSE(Ms), marker='^', color=colors, label=legend_Exponential)
    ax[2].scatter(np.zeros(num_Ms), lnZ_Gumbel_MSE(Ms), marker='o', color=colors, label=legend_Gumbel)
    ax[2].scatter(np.ones(num_Ms), lnZ_Exponential_MSE(Ms), marker='^', color=colors, label=legend_Exponential)

    # Remaining tricks MSE were estimated by sampling
    labels = ['$M=%d$' % (M) for M in Ms]
    plot_MSEs_to_axis(ax[0], alphas, MSEs_Z, MSE_stdevs_Z, do_confidence, labels, colors)
    plot_MSEs_to_axis(ax[2], alphas, MSEs_lnZ, MSE_stdevs_lnZ, do_confidence, labels, colors)

    # Finalize plot
    ax[0].set_ylim((5*1e-3, 10))
    ax[2].set_ylim((5*1e-3, 10))
    # Reuse the legend entries of the first data panel in the legend-only axis.
    legend_handles, legend_labels = ax[0].get_legend_handles_labels()
    remove_chartjunk(ax[1])
    ax[1].spines["bottom"].set_visible(False)
    # Booleans instead of "off" strings: recent matplotlib releases no longer
    # interpret "off" as False, which would leave the ticks/labels visible.
    ax[1].tick_params(axis="both", which="both", bottom=False, top=False, labelbottom=False,
                      left=False, right=False, labelleft=False)
    ax[1].legend(legend_handles, legend_labels, frameon=False, loc='upper center', bbox_to_anchor=[0.44, 1.05])
    plt.tight_layout()
    save_plot(fig, 'figures/fig2_K%d' % (K))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号