def plot_policy_learned(data_unpickle, color, fig_dir=None,
                        state_range=(-3, 3), step=0.01):
    """Plot the mean and log-std output of a learned policy over a 1-D state grid.

    Both curves are drawn on the *current* matplotlib figure (no new figure is
    created), so repeated calls with different colors overlay on one plot.

    Args:
        data_unpickle: Mapping holding the trained policy under key
            ``'policy'``. The policy's ``get_action(obs)`` result must have, as
            its second element, a mapping with ``'mean'`` and ``'log_std'``
            entries (presumably size-1 values, i.e. a 1-D action space, since
            each is stored into a scalar slot -- confirm against the policy).
        color: Color of the mean curve. The log-std curve uses this color
            scaled by 0.7, so it must be numeric (e.g. an RGB tuple/list or a
            numpy array).
        fig_dir: Directory in which to save the plot as ``policy_learned``
            (no extension, so matplotlib's default format is used). If falsy,
            nothing is saved and a message is printed.
        state_range: ``(low, high)`` bounds of the evaluated states; ``high``
            is excluded, as with ``np.arange``. Defaults to ``(-3, 3)``.
        step: Spacing between consecutive evaluated states. Defaults to 0.01.
    """
    policy = data_unpickle['policy']
    states = np.arange(state_range[0], state_range[1], step)
    means = np.zeros(np.size(states))
    logstd = np.zeros(np.size(states))
    for i, s in enumerate(states):
        # Query the policy once per state and read both distribution
        # parameters from the same info dict (the original queried it twice).
        agent_info = policy.get_action(np.array((s,)))[1]
        means[i] = agent_info['mean']
        logstd[i] = agent_info['log_std']
    plt.plot(states, means, color=color, label='mean')
    # np.asarray lets tuple/list colors be scaled too; `tuple * 0.7` raises.
    plt.plot(states, logstd, color=np.asarray(color) * 0.7, label='logstd')
    plt.legend(loc='right')  # equivalent to the numeric code loc=5
    plt.title('Final policy')
    plt.xlabel('state')
    plt.ylabel('Action')
    if fig_dir:
        plt.savefig(os.path.join(fig_dir, 'policy_learned'))
    else:
        print("No directory for saving plots")
评论列表
文章目录