VisCostCallback.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:mxnet_workshop 作者: NervanaSystems 项目源码 文件源码
def _process_batch(self, param, name):
        """Record the current metric value for one batch and refresh the plot.

        On the first call the figure is displayed with
        ``show(self.fig, notebook_handle=True)``. Each call then reads the first
        (name, value) pair from ``param.eval_metric``, resets the metric, and
        appends the (possibly averaged and clamped) cost to the training or
        validation data source. The notebook is refreshed at most once every
        ``self.update_thresh_s`` seconds.

        Args:
            param: object exposing ``nbatch``, ``epoch`` and ``eval_metric``
                (presumably an MXNet batch-end parameter object -- TODO confirm).
            name: ``'train'`` adds a point to ``self.train_source``; ``'eval'``
                adds one to ``self.val_source``. Any other value still reads and
                resets the metric but plots nothing.
        """
        # Show the figure exactly once. ``show`` may hand back no handle (the
        # ``push_notebook()`` fallback at the bottom exists for that case), so
        # ``self.handle is None`` alone cannot tell whether the figure was
        # already shown -- relying on it would call ``show`` again every batch.
        if self.handle is None and not getattr(self, '_figure_shown', False):
            self.handle = show(self.fig, notebook_handle=True)
            self._figure_shown = True

        now = default_timer()

        # Count a new epoch whenever a batch index of 0 is seen.
        if param.nbatch == 0:
            self.epoch = self.epoch + 1

        # x coordinate for training points: epoch plus the fraction of it done
        # (assumes self.total is the number of batches per epoch -- confirm).
        epoch_pos = float(param.nbatch) / self.total + param.epoch

        if param.eval_metric is None:
            return

        name_value = param.eval_metric.get_name_value()
        # Reset right after reading so the next reading starts afresh.
        param.eval_metric.reset()

        # Only the first (name, value) pair is plotted.
        cost = name_value[0][1]

        if name == 'train':
            # Presumably a running/smoothed average of the training cost.
            cost = self.get_average_cost(cost)

        # Clamp NaN and anything above 4000 to 4000 so every plotted y value
        # is a finite number no larger than 4000.
        if math.isnan(cost) or cost > 4000:
            cost = 4000

        if name == 'train':
            self.train_source.data['x'].append(epoch_pos)
            self.train_source.data['y'].append(cost)
        elif name == 'eval':
            # Validation points sit at x = epoch + 1, after that epoch's
            # training points.
            self.val_source.data['x'].append(param.epoch + 1)
            self.val_source.data['y'].append(cost)

        # Throttle notebook refreshes to one per update_thresh_s seconds.
        if now - self.last_update > self.update_thresh_s:
            self.last_update = now

            if self.handle is not None:
                push_notebook(handle=self.handle)
            else:
                push_notebook()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号