a3c.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:kerlym 作者: osh 项目源码 文件源码
def update_stats(self, stats, tid=0):
        self.e += 1
        # update stats store
        for k in stats.keys():
            self.stats[k].add( stats[k] )

        # only plot from thread 0
        if self.stats == None or tid > 0:
            return

        # plot if its time
        if(self.e >= self.next_plot):
            self.next_plot = self.e + self.stats_rate
            if self.ipy_clear:
                from IPython import display
                display.clear_output(wait=True)
            fig = plt.figure(1)
            fig.canvas.set_window_title("A3C Training Stats for %s"%(self.experiment))
            plt.clf()
            plt.subplot(2,2,1)
            self.stats["tr"].plot()
            plt.title("Total Reward per Episode")
            plt.xlabel("Episode")
            plt.ylabel("Total Reward")
            plt.legend(loc=2)
            plt.subplot(2,2,2)
            self.stats["ft"].plot()
            plt.title("Finishing Time per Episode")
            plt.xlabel("Episode")
            plt.ylabel("Finishing Time")
            plt.legend(loc=2)
            plt.subplot(2,2,3)
            self.stats["maxvf"].plot2(fill_col='lightblue', label='Avg Max VF')
            self.stats["minvf"].plot2(fill_col='slategrey', label='Avg Min VF')
            plt.title("Value Function Outputs")
            plt.xlabel("Episode")
            plt.ylabel("Value Fn")
            plt.legend(loc=2)
            ax = plt.subplot(2,2,4)
            self.stats["cost"].plot2()
            plt.title("Training Loss")
            plt.xlabel("Training Epoch")
            plt.ylabel("Loss")
            try:
#                ax.set_yscale("log", nonposy='clip')
                plt.tight_layout()
            except:
                pass
            plt.show(block=False)
            plt.draw()
            plt.pause(0.001)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号