fig3b_plot.py source code

python

Project: gumbel-relatives  Author: matejbalog
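import numpy as np
import matplotlib.pyplot as plt

# Helpers such as json_load, MAPs_to_estimator_MSE_vs_alpha,
# matplotlib_configure_as_notebook, plot_MSEs_to_axis and save_plot
# are not defined in this excerpt; they come from elsewhere in the project.


def main(args_dict):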
    # Extract configuration from command line arguments
    MK = np.array(args_dict['MK'])
    M = 100
    K = MK // M  # integer division, so K stays an integer under Python 3
    print('M = %d; K = %d' % (M, K))
    x_type = args_dict['x_type']
    deltas = args_dict['deltas']
    do_confidence = args_dict['confidence']

    # Load data from JSON files generated by (non-public) Matlab code
    jsons = [json_load('data/bandits_normal_delta%s_MK%d.json' % (delta, MK)) for delta in deltas]
    # Loop variable is named `data` to avoid shadowing the standard `json` module
    lnZs = np.array([data['lnZ'] for data in jsons])
    MAPs = np.array([data['MAPs_ttest'] for data in jsons])

    # Estimate estimator MSEs for the various tricks (as specified by alphas)
    alphas = np.linspace(-0.2, 1.5, 100)
    MSEs, MSEs_stdev = MAPs_to_estimator_MSE_vs_alpha(1, MAPs, lnZs, alphas, K)

    # Set up plot
    matplotlib_configure_as_notebook()
    fig, ax = plt.subplots(1, 1, facecolor='w', figsize=(4.25, 3.25))
    ax.set_xlabel('trick parameter $\\alpha$')
    ax.set_ylabel('MSE of estimator of $\\ln Z$')

    # Plot the MSEs
    labels = ['$\\delta = %g$' % (delta) for delta in deltas]
    colors = [plt.cm.plasma((np.log10(delta) - (-3)) / (0 - (-3))) for delta in deltas]
    plot_MSEs_to_axis(ax, alphas, MSEs, MSEs_stdev, do_confidence, labels, colors)

    # Finalize plot
    for vertical in [0.0, 1.0]:
        ax.axvline(vertical, color='black', linestyle='dashed', alpha=.7)
    ax.annotate('Gumbel trick', xy=(0.0, 0.0052), rotation=90, horizontalalignment='right', verticalalignment='bottom')
    ax.annotate('Exponential trick', xy=(1.0, 0.0052), rotation=90, horizontalalignment='right', verticalalignment='bottom')
    lgd = ax.legend(loc='upper center')
    ax.set_ylim((5*1e-3, 5*1e-2))
    save_plot(fig, 'figures/fig3b', bbox_extra_artists=(lgd,))
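
The color assignment in the plotting step places each delta on the [0, 1] range of the plasma colormap according to its position on a log scale between 10^-3 and 10^0. The following is a minimal sketch of just that normalization, using only the standard library. The function name `log_position` and its parameters `lo_exp` and `hi_exp` are hypothetical and are not part of the original project.

```python
import math


def log_position(delta, lo_exp=-3.0, hi_exp=0.0):
    """Return where delta falls between 10**lo_exp and 10**hi_exp on a log scale.

    With the defaults this reproduces the expression
    (np.log10(delta) - (-3)) / (0 - (-3)) from the plotting code.
    """
    return (math.log10(delta) - lo_exp) / (hi_exp - lo_exp)


if __name__ == "__main__":
    # Hypothetical delta values, chosen only to illustrate the mapping
    for delta in [0.001, 0.01, 0.1, 1.0]:
        print(delta, round(log_position(delta), 3))
```

A delta of 10^-3 lands at the low end of the colormap and a delta of 1 at the high end. Deltas outside that range give positions outside [0, 1], so the exponent bounds would need adjusting if the set of deltas changed.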