callbacks.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目: ngraph 作者: NervanaSystems 项目源码 文件源码
def __call__(self, transformer, callback_data, phase, data, idx):
        """Record periodic loss evaluations and their timing in ``callback_data``.

        Behaviour depends on ``phase``:

        * ``CallbackPhase.train_pre_`` -- reads ``total_iterations`` from
          ``callback_data['config'].attrs`` and stores it on
          ``self.total_iterations``. It then creates a dataset named
          ``cost/<key>`` for every key in ``self.interval_loss_comp.output_keys``,
          plus one named ``time/loss``. Each has shape
          ``(self.total_iterations // self.frequency,)``.
        * ``CallbackPhase.train_post`` -- runs ``loop_eval`` over
          ``self.dataset`` with ``self.interval_loss_comp`` and writes a summary
          line through ``tqdm.write``.
        * ``CallbackPhase.minibatch_post`` with ``(idx + 1) % self.frequency == 0``
          -- runs the same evaluation and stores each returned loss in slot
          ``idx // self.frequency`` of its ``cost/<name>`` dataset. It then stores
          the time (``default_timer`` difference) elapsed since just before the
          evaluation in ``time/loss`` and writes a progress line through
          ``tqdm.write``. The measured time includes the cost writes.

        Any other phase, or a ``minibatch_post`` call that does not end an
        interval, does nothing.

        Args:
            transformer: not used by this method.
            callback_data: object offering ``create_dataset(name, shape)`` and
                item assignment by dataset name (h5py-like -- assumed; confirm
                against caller).
            phase: the current ``CallbackPhase`` value.
            data: not used by this method.
            idx: zero-based iteration counter; only read for
                ``CallbackPhase.minibatch_post``.
        """
        if phase == CallbackPhase.train_pre_:
            self.total_iterations = callback_data['config'].attrs['total_iterations']
            # One slot per completed interval of ``self.frequency`` iterations.
            slot_shape = (self.total_iterations // self.frequency,)
            for key in self.interval_loss_comp.output_keys:
                callback_data.create_dataset("cost/{}".format(key), slot_shape)
            callback_data.create_dataset("time/loss", slot_shape)
            return

        if phase == CallbackPhase.train_post:
            final_losses = loop_eval(self.dataset, self.interval_loss_comp)
            tqdm.write("Training complete.  Avg losses: {}".format(final_losses))
            return

        # ``idx`` is only touched once the phase is known to be minibatch_post.
        interval_done = (phase == CallbackPhase.minibatch_post
                         and (idx + 1) % self.frequency == 0)
        if not interval_done:
            return

        tic = default_timer()
        slot = idx // self.frequency

        interval_losses = loop_eval(self.dataset, self.interval_loss_comp)

        for name, value in interval_losses.items():
            callback_data["cost/{}".format(name)][slot] = value

        callback_data["time/loss"][slot] = default_timer() - tic
        tqdm.write("Interval {} Iteration {} complete.  Avg losses: {}".format(
            slot + 1, idx + 1, interval_losses))
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号