def update_stats(self, stats, tid=0):
    """Record one batch of statistics and periodically redraw the plots.

    Increments the update counter ``self.e``, appends every value in
    ``stats`` to the matching entry of ``self.stats`` and, once
    ``self.e`` reaches ``self.next_plot``, redraws a 2x2 figure showing
    total reward, finishing time, value-function outputs and training
    loss.

    Args:
        stats: Mapping of statistic name to the new value to record.
            Every key must already exist in ``self.stats``.
        tid: Id of the calling thread. Every thread contributes to the
            statistics store, but only thread 0 draws.

    Returns:
        None.
    """
    self.e += 1

    # Check for a missing store *before* indexing into it; otherwise the
    # loop below fails with a TypeError and this guard is never reached.
    if self.stats is None:
        return

    # update stats store
    for key in stats.keys():
        self.stats[key].add(stats[key])

    # only plot from thread 0
    if tid > 0:
        return

    # plot if its time
    if self.e < self.next_plot:
        return
    self.next_plot = self.e + self.stats_rate

    if self.ipy_clear:
        from IPython import display
        display.clear_output(wait=True)

    fig = plt.figure(1)
    # FigureCanvasBase.set_window_title was deprecated in Matplotlib 3.4
    # and later removed; the figure manager owns the window title.
    manager = getattr(fig.canvas, "manager", None)
    if manager is not None:
        manager.set_window_title(
            "DQN Training Stats for %s" % (self.experiment))
    plt.clf()

    plt.subplot(2, 2, 1)
    self.stats["tr"].plot()
    plt.title("Total Reward per Episode")
    plt.xlabel("Episode")
    plt.ylabel("Total Reward")
    plt.legend(loc=2)

    plt.subplot(2, 2, 2)
    self.stats["ft"].plot()
    plt.title("Finishing Time per Episode")
    plt.xlabel("Episode")
    plt.ylabel("Finishing Time")
    plt.legend(loc=2)

    plt.subplot(2, 2, 3)
    self.stats["maxvf"].plot2(fill_col='lightblue', label='Avg Max VF')
    self.stats["minvf"].plot2(fill_col='slategrey', label='Avg Min VF')
    plt.title("Value Function Outputs")
    plt.xlabel("Episode")
    plt.ylabel("Value Fn")
    plt.legend(loc=2)

    plt.subplot(2, 2, 4)
    self.stats["cost"].plot2()
    plt.title("Training Loss")
    plt.xlabel("Training Epoch")
    plt.ylabel("Loss")

    try:
        plt.tight_layout()
    except Exception:
        # Layout tightening is purely cosmetic, so a failure here must
        # not abort training. KeyboardInterrupt/SystemExit still
        # propagate because only Exception is caught.
        pass

    # Non-blocking refresh so the caller keeps running while the window
    # updates.
    plt.show(block=False)
    plt.draw()
    plt.pause(0.001)
# (Web-page residue, not part of the code: "Comment list" / "Article table of contents".)