plt_results2D.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:snn4hrl 作者: florensacc 项目源码 文件源码
def plot_all_policy_at0(path_experiment, color, num_iter=100, fig_dir=None):
    """Plot the policy's action mean and log-std at observation 0 per iteration.

    For each iteration ``i`` in ``range(num_iter)`` the snapshot
    ``<path_experiment>/itr_<i>.pkl`` is loaded, its ``'policy'`` entry is
    queried with the observation ``np.array((0,))``, and the ``'mean'`` and
    ``'log_std'`` entries of the second element returned by ``get_action``
    are recorded. Both series are drawn on the current matplotlib figure.

    Args:
        path_experiment: Directory containing the ``itr_<i>.pkl`` snapshots.
        color: Colour of the mean curve. A NumPy array is scaled exactly as
            before (``color * 0.7``) for the log-std curve; any other
            matplotlib colour spec (name, hex string, RGB/RGBA tuple) is
            converted to RGB first so it can be darkened too.
        num_iter: Number of consecutive iterations to load, starting at 0.
        fig_dir: Directory in which to save the figure as ``policy_at_0``.
            If falsy, nothing is saved and a notice is printed instead.
    """
    # pyplot already depends on matplotlib.colors, so this import is free.
    from matplotlib import colors as mcolors

    mean_at_0 = []
    log_std_at_0 = []
    for itr in range(num_iter):
        # NOTE(review): joblib.load unpickles the file, which can execute
        # arbitrary code -- only point this at trusted experiment folders.
        snapshot = joblib.load(
            os.path.join(path_experiment, 'itr_{}.pkl'.format(itr)))
        policy = snapshot['policy']
        action_at_0 = policy.get_action(np.array((0,)))
        # Element 1 of the returned value holds the distribution parameters.
        dist_info = action_at_0[1]
        mean_at_0.append(dist_info['mean'])
        log_std_at_0.append(dist_info['log_std'])

    # Darker shade of the same colour for the log-std curve. Arrays keep the
    # original arithmetic; strings/tuples used to raise TypeError here.
    if isinstance(color, np.ndarray):
        darker_color = color * 0.7
    else:
        darker_color = tuple(0.7 * channel for channel in mcolors.to_rgb(color))

    iterations = list(range(num_iter))
    plt.plot(iterations, mean_at_0, color=color, label='mean at 0')
    plt.plot(iterations, log_std_at_0, color=darker_color, label='logstd at 0')
    plt.title('How the policy varies across iterations')
    plt.xlabel('iteration')
    # The second curve is the log standard deviation, not the variance.
    plt.ylabel('mean and log-std at 0')
    plt.legend(loc=3)
    if fig_dir:
        plt.savefig(os.path.join(fig_dir, 'policy_at_0'))
    else:
        print("No directory for saving plots")


## plot for all the experiments
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号