plt_results1D.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:snn4hrl 作者: florensacc 项目源码 文件源码
def plot_policy_learned(data_unpickle, color, fig_dir=None, *, x_range=(-3, 3), step=0.01):
    """Plot the mean and log-std of a learned 1-D policy over a grid of states.

    The policy stored under ``data_unpickle['policy']`` is queried once per
    grid state via ``get_action``.  The ``'mean'`` and ``'log_std'`` entries
    of the info dict it returns (element ``[1]``) are drawn on the current
    matplotlib figure.  The figure is neither created nor closed here.

    Args:
        data_unpickle: Mapping holding the policy object under ``'policy'``.
        color: Line color for the mean curve.  The log-std curve uses the
            same color scaled by 0.7, so this must be numeric (an RGB/RGBA
            tuple, list or array); named color strings are not supported.
        fig_dir: Directory in which the figure is saved as
            ``policy_learned``.  If falsy, nothing is saved and a message is
            printed instead.
        x_range: ``(start, stop)`` of the state grid; ``stop`` is excluded,
            as with ``np.arange``.  Defaults to ``(-3, 3)``.
        step: Spacing of the state grid.  Defaults to ``0.01``.
    """
    policy = data_unpickle['policy']

    start, stop = x_range
    x = np.arange(start, stop, step)
    # Explicit float buffers, so an all-integer grid spec cannot truncate values.
    means = np.zeros(x.shape)
    logstd = np.zeros(x.shape)
    for i, s in enumerate(x):
        # Query the policy once per state and reuse its info dict for both
        # curves.  A second call is redundant work, and it may draw a fresh
        # sample if the policy is stochastic.
        info = policy.get_action(np.array((s,)))[1]
        # NOTE(review): assumes a 1-D action space, so 'mean'/'log_std' hold a
        # single value (scalar or size-1 array).  .item() raises otherwise.
        means[i] = np.asarray(info['mean']).item()
        logstd[i] = np.asarray(info['log_std']).item()

    plt.plot(x, means, color=color, label='mean')
    # Darker shade of the same color for log-std.  asarray lets plain tuples
    # and lists work, whereas tuple * 0.7 would raise TypeError.
    plt.plot(x, logstd, color=np.asarray(color, dtype=float) * 0.7, label='logstd')
    plt.legend(loc=5)  # location code 5 == 'right'
    plt.title('Final policy')
    plt.xlabel('state')
    plt.ylabel('Action')
    if fig_dir:
        plt.savefig(os.path.join(fig_dir, 'policy_learned'))
    else:
        print("No directory for saving plots")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号