def validate(val_loader, model, criterion, evaluation, logger=None):
    """Run one pass over the validation set and report average loss/accuracy.

    Args:
        val_loader: Iterable yielding ``(g, h, e, target)`` batches. ``g``,
            ``h`` and ``e`` are passed to ``model(g, h, e)``; the batch size
            used for weighting is taken from ``g.size(0)``.
        model: Callable module invoked as ``model(g, h, e)``; it is switched
            to evaluation mode via ``model.eval()``.
        criterion: Loss callable ``criterion(output, target)``; its result is
            expected to hold a single value (read with ``.item()``).
        evaluation: Metric callable ``evaluation(output, target, topk=(1,))``
            whose first returned element is the accuracy value that gets
            averaged (presumably top-1 accuracy — confirm against its
            definition).
        logger: Optional object exposing ``log_value(name, value)``; when
            given, the epoch-average loss and accuracy are logged to it.

    Returns:
        The batch-size-weighted average accuracy over ``val_loader``.

    Note:
        Reads the module-level ``args.cuda`` flag to decide whether batches
        are moved to the GPU.
    """
    losses = AverageMeter()
    accuracies = AverageMeter()

    # switch to evaluate mode
    model.eval()

    # Validation needs no gradients; disabling autograd avoids building a
    # graph per batch, saving memory and time.
    with torch.no_grad():
        for g, h, e, target in val_loader:
            # Prepare input data
            target = torch.squeeze(target).type(torch.LongTensor)
            # squeeze() turns a batch of one into a 0-dim tensor, which the
            # loss cannot match against a batched output; restore the batch dim.
            if target.dim() == 0:
                target = target.view(1)
            if args.cuda:
                g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda()

            # Compute output
            output = model(g, h, e)

            # Logs
            test_loss = criterion(output, target)
            acc = evaluation(output, target, topk=(1,))[0]
            batch_size = g.size(0)
            # Extract plain Python numbers: indexing a 0-dim tensor with [0]
            # raises IndexError on modern PyTorch, and float() also accepts a
            # plain number should `evaluation` return one.
            losses.update(test_loss.item(), batch_size)
            accuracies.update(float(acc), batch_size)

    print(' * Average Accuracy {acc.avg:.3f}; Average Loss {loss.avg:.3f}'
          .format(acc=accuracies, loss=losses))

    if logger is not None:
        logger.log_value('test_epoch_loss', losses.avg)
        logger.log_value('test_epoch_accuracy', accuracies.avg)

    return accuracies.avg
# Comment list (评论列表)
# Table of contents (文章目录)