def __call__(self, transformer, callback_data, phase, data, idx):
    """Record evaluation losses at fixed iteration intervals during training.

    What happens depends on ``phase``:

    * ``CallbackPhase.train_pre_``: reads ``total_iterations`` from
      ``callback_data['config'].attrs``, stores it on ``self``, and creates
      one ``cost/<name>`` dataset per key in
      ``self.interval_loss_comp.output_keys`` plus a ``time/loss`` dataset,
      each of shape ``(total_iterations // frequency,)``.
    * ``CallbackPhase.train_post``: evaluates the losses over
      ``self.dataset`` via ``loop_eval`` and prints them.
    * ``CallbackPhase.minibatch_post``: only when ``idx + 1`` is a multiple
      of ``self.frequency``, evaluates the losses, stores each one and the
      elapsed ``default_timer`` time in slot ``idx // frequency``, and
      prints them.

    Any other phase is a no-op.

    Args:
        transformer: Not used by this callback.
        callback_data: Store offering ``create_dataset`` and item access
            (presumably an h5py-style group -- confirm against the caller).
        phase: The callback phase being signalled.
        data: Not used by this callback.
        idx: Minibatch index; only read for ``minibatch_post`` and treated
            as zero-based by the ``idx + 1`` interval check.
    """
    if phase == CallbackPhase.train_pre_:
        config_attrs = callback_data["config"].attrs
        self.total_iterations = config_attrs["total_iterations"]
        slots_shape = (self.total_iterations // self.frequency,)
        for key in self.interval_loss_comp.output_keys:
            callback_data.create_dataset("cost/{}".format(key), slots_shape)
        callback_data.create_dataset("time/loss", slots_shape)
        return

    if phase == CallbackPhase.train_post:
        final_losses = loop_eval(self.dataset, self.interval_loss_comp)
        tqdm.write("Training complete. Avg losses: {}".format(final_losses))
        return

    # ``and`` short-circuits, so ``idx`` is only touched for minibatch_post.
    at_interval_end = (
        phase == CallbackPhase.minibatch_post
        and (idx + 1) % self.frequency == 0
    )
    if not at_interval_end:
        return

    # The timer is started before evaluation so "time/loss" covers loop_eval
    # and the writes of the individual loss values.
    began = default_timer()
    slot = idx // self.frequency
    interval_losses = loop_eval(self.dataset, self.interval_loss_comp)
    for key, value in interval_losses.items():
        callback_data["cost/{}".format(key)][slot] = value
    callback_data["time/loss"][slot] = default_timer() - began

    report = "Interval {} Iteration {} complete. Avg losses: {}".format(
        slot + 1, idx + 1, interval_losses)
    tqdm.write(report)
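# Comment list (web-page navigation residue; not part of the code)
# Table of contents (web-page navigation residue; not part of the code)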