def evaluate(model: Model,
             dataset: Dataset,
             iterator: DataIterator,
             cuda_device: int) -> Dict[str, Any]:
    """Run ``model`` once over every batch of ``dataset`` and report its metrics.

    The model is switched to evaluation mode with ``model.eval()``. ``iterator``
    is then asked for a single pass (``num_epochs=1``) over ``dataset`` with
    ``for_training=False`` on ``cuda_device``. Each batch is fed to the model
    as keyword arguments. After every batch, the progress bar's label is
    refreshed with the current values from ``model.get_metrics()``.

    Args:
        model: The model to evaluate; called as ``model(**batch)``.
        dataset: The data to iterate over.
        iterator: Callable producing batches; also supplies the batch count
            used as the progress bar's total.
        cuda_device: Device id forwarded to ``iterator``.

    Returns:
        The result of a final ``model.get_metrics()`` call after all batches
        have been processed. (Presumably the model accumulates metric state
        across batches during forward calls — confirm against ``Model``.)
    """
    model.eval()

    batches = iterator(dataset, num_epochs=1, cuda_device=cuda_device, for_training=False)
    logger.info("Iterating over dataset")
    progress = tqdm.tqdm(batches, total=iterator.get_num_batches(dataset))

    for batch in progress:
        model(**batch)
        # Show the running metric values in the progress-bar label.
        running = model.get_metrics()
        parts = ("%s: %.2f" % (key, val) for key, val in running.items())
        progress.set_description(', '.join(parts) + " ||")

    return model.get_metrics()
# (Stray web-page text captured during extraction: "Comment list" / "Table of contents".)