def __call__(self, epoch):
    """Run a full evaluation pass over ``self._dataset`` and report the metrics.

    Steps:
      1. Lazily build the evaluation batches (once; cached in ``self._batches``).
      2. Run ``self.reader.model_module`` on every batch and collect, for each
         name in ``self._metrics``, the value returned by ``self.apply_metrics``.
      3. Combine the per-batch values with ``self.combine_metrics`` and record
         them via ``add_to_history``.
      4. Log one tab-separated line, push each metric to ``update_summary``,
         and optionally append them to the file ``self._write_metrics_to``.
      5. Call ``self._side_effect`` (if set), storing its return value in
         ``self._side_effect_state``.

    Args:
        epoch: Current epoch number, stored in the history entry and printed
            with ``%d`` in the log line.
    """
    if self._batches is None:
        logger.info("Preparing evaluation data...")
        # Built once and reused on every later call.
        # NOTE(review): this assumes batch_generator returns a re-iterable
        # object; a plain generator would be exhausted after the first pass.
        self._batches = self.reader.input_module.batch_generator(
            self._dataset, self._batch_size, is_eval=True)

    logger.info("Started evaluation %s", self._info)

    # Exact number of batches (ceiling division), never below 1. The previous
    # `n // b + 1` over-counted by one whenever n was a multiple of b, which
    # skewed the displayed percentage and ETA.
    num_batches = max(1, -(-len(self._dataset) // self._batch_size))
    bar = progressbar.ProgressBar(
        max_value=num_batches,
        widgets=[' [', progressbar.Timer(), '] ', progressbar.Bar(), ' (', progressbar.ETA(), ') '])

    # metric name -> list of per-batch values
    per_batch = defaultdict(list)
    for i, batch in bar(enumerate(self._batches)):
        # The slice assumes batch i holds dataset items [i*b, (i+1)*b), i.e. the
        # eval generator keeps dataset order and does not drop items -- TODO confirm.
        inputs = self._dataset[i * self._batch_size:(i + 1) * self._batch_size]
        predictions = self.reader.model_module(batch, self._ports)
        batch_metrics = self.apply_metrics(inputs, predictions)
        for key in self._metrics:
            per_batch[key].append(batch_metrics[key])

    metrics = self.combine_metrics(per_batch)
    super().add_to_history(metrics, self._iter, epoch)

    metric_names = sorted(metrics.keys())
    parts = ["Epoch %d\tIter %d\ttotal %d" % (epoch, self._iter, self._total)]
    for name in metric_names:
        value = metrics[name]
        parts.append('%s: %.3f' % (name, value))
        self.update_summary(self._iter, self._info + '_' + name, value)
    parts.append(self._info)
    logger.info('\t'.join(parts))

    # Open the metrics file once for all metrics (previously once per metric).
    # Like before, the file is not touched at all when there are no metrics.
    if self._write_metrics_to is not None and metric_names:
        with open(self._write_metrics_to, 'a') as f:
            for name in metric_names:
                # `{2:.5}`: general format with 5 significant digits.
                f.write("{0} {1} {2:.5}\n".format(datetime.now(), self._info + '_' + name,
                                                   np.round(metrics[name], 5)))

    if self._side_effect is not None:
        self._side_effect_state = self._side_effect(metrics, self._side_effect_state)
评论列表
文章目录