plot.py source code

python
Project: rl_algorithms  Author: DanielTakeshi
def plot_one_dir(args, directory):
    """ The actual plotting code.

    `directory` is iterated over as a sequence of sub-directory names under
    `args.expdir`, each holding a tab-delimited `log.txt`. Often that is a
    single directory, i.e. one random seed, but it is better to have multiple
    random seeds, so this code generalizes to any number of them. For ES, we
    should store the output at *every* iteration, so A['TotalIterations']
    should look like np.arange(...), but nothing here depends on that, in case
    Ray can help me run for many more iterations.
    """
    print("Now plotting based on directory {} ...".format(directory))

    ### Figure 1: The log.txt file.
    num = len(ATTRIBUTES)
    fig, axes = subplots(num, figsize=(12,3*num))
    for (dd, cc) in zip(directory, COLORS):
        A = np.genfromtxt(join(args.expdir, dd, 'log.txt'),
                          delimiter='\t', dtype=None, names=True)
        x = A['TotalIterations']
        for (i,attr) in enumerate(ATTRIBUTES):
            axes[i].plot(x, A[attr], '-', lw=lw, color=cc, label=dd)
            axes[i].set_ylabel(attr, fontsize=ysize)
            axes[i].tick_params(axis='x', labelsize=tick_size)
            axes[i].tick_params(axis='y', labelsize=tick_size)
            axes[i].legend(loc='best', ncol=1, prop={'size':legend_size})
    plt.tight_layout()
    plt.savefig(args.out+'_log.png')

    ### Figure 2: Error regions.
    num = len(directory)
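    if num == 1:
        # With a single panel, matplotlib's subplots returns one Axes rather
        # than an array, so request a second (empty) panel to keep axes[i]
        # indexable below.
        num += 1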
    fig, axes = subplots(1,num, figsize=(12*num,10))
    for (i, (dd, cc)) in enumerate(zip(directory, COLORS)):
        A = np.genfromtxt(join(args.expdir, dd, 'log.txt'),
                          delimiter='\t', dtype=None, names=True)
        axes[i].plot(A['TotalIterations'], A["FinalAvgReturns"], 
                     color=cc, marker='x', ms=ms, lw=lw)
        axes[i].fill_between(A['TotalIterations'],
                             A["FinalAvgReturns"] - A["FinalStdReturns"],
                             A["FinalAvgReturns"] + A["FinalStdReturns"],
                             alpha = error_region_alpha,
                             facecolor='y')
        axes[i].set_ylim(ENV_TO_YLABELS[args.envname])
        axes[i].tick_params(axis='x', labelsize=tick_size)
        axes[i].tick_params(axis='y', labelsize=tick_size)
        axes[i].set_title("Mean Episode Rewards ({})".format(dd), fontsize=title_size)
        axes[i].set_xlabel("ES Iterations", fontsize=xsize)
        axes[i].set_ylabel("Rewards", fontsize=ysize)
    plt.tight_layout()
    plt.savefig(args.out+'_rewards_std.png')
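
The function above relies on module-level names that are not shown here (`ATTRIBUTES`, `COLORS`, `ENV_TO_YLABELS`, `lw`, `ms`, and so on), so it cannot be run on its own. The following is a minimal, self-contained sketch of the Figure 2 idea only: read a tab-delimited log with `np.genfromtxt(..., names=True)` and shade one standard deviation around the mean with `fill_between`. The sample log contents, the helper names `load_log` and `plot_reward_bands`, and the label `seed_0` are assumptions made for illustration and are not part of the original project. Passing `squeeze=False` to `plt.subplots` is one way to keep `axes` two-dimensional even for a single panel, which avoids the extra empty panel used in the original.

```python
import io

import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

# Hypothetical log contents; the real log.txt is written by the training code.
SAMPLE_LOG = (
    "TotalIterations\tFinalAvgReturns\tFinalStdReturns\n"
    "0\t10.0\t2.0\n"
    "1\t12.5\t1.5\n"
    "2\t15.0\t1.0\n"
)


def load_log(source):
    """Read a tab-delimited log whose first row holds the column names."""
    return np.genfromtxt(source, delimiter="\t", names=True)


def plot_reward_bands(logs, out_path=None):
    """Plot mean return with a +/- one std band, one panel per log.

    `logs` maps a label (e.g. a directory name) to a structured array
    as returned by `load_log`.
    """
    num = len(logs)
    # squeeze=False keeps `axes` a 2-D array even when num == 1.
    fig, axes = plt.subplots(1, num, squeeze=False, figsize=(6 * num, 4))
    for ax, (label, A) in zip(axes[0], logs.items()):
        x = A["TotalIterations"]
        mean = A["FinalAvgReturns"]
        std = A["FinalStdReturns"]
        ax.plot(x, mean, marker="x")
        ax.fill_between(x, mean - std, mean + std, alpha=0.2, facecolor="y")
        ax.set_title("Mean Episode Rewards ({})".format(label))
        ax.set_xlabel("ES Iterations")
        ax.set_ylabel("Rewards")
    fig.tight_layout()
    if out_path is not None:
        fig.savefig(out_path)
    return fig, axes


A = load_log(io.StringIO(SAMPLE_LOG))
fig, axes = plot_reward_bands({"seed_0": A})
```

To write the figure to disk, pass a path, for example `plot_reward_bands({"seed_0": A}, out_path="rewards_std.png")`.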