def deploy(x, labels, *, model=None, criterion=None):
    """Evaluate a classifier on one batch; return loss, top-1 and top-5 accuracy.

    Args:
        x: Input batch passed to the model.
        labels: 1-D tensor of integer class indices, one per row of the
            model output.
        model: Callable mapping ``x`` to a 2-D score tensor
            ``(batch_size, num_classes)``. Defaults to the global ``m``
            that this function originally used.
        criterion: Callable ``criterion(pred, labels)`` returning a
            single-element loss tensor. Defaults to the global ``crit``.

    Returns:
        ``(loss, top1_acc, top5_acc)``: the loss as a Python float, and the
        fraction of rows whose true label is ranked first / within the
        top five scores.
    """
    # Fall back to the globals the original implementation relied on, so
    # existing callers that pass only (x, labels) keep working.
    model = m if model is None else model
    criterion = crit if criterion is None else criterion

    # Only scalar metrics leave this function, so no autograd graph is needed.
    with torch.no_grad():
        pred = model(x)
        loss = criterion(pred, labels)
        # order[i, r] = class index holding rank r (scores sorted descending).
        _, order = pred.topk(pred.size(1), dim=1)
        # Sorting the permutation ascending inverts it:
        # ranking[i, c] = rank of class c in row i.
        _, ranking = order.topk(order.size(1), dim=1, largest=False)
        # Rank of each row's true label (0 = highest score).
        rank = ranking.gather(1, labels[:, None]).squeeze(1).cpu().numpy()

    top1_acc = np.mean(rank == 0)
    top5_acc = np.mean(rank < 5)
    # `loss.data[0]` raises IndexError on 0-dim tensors in modern PyTorch;
    # `.item()` is the supported way to extract a Python scalar.
    return loss.item(), top1_acc, top5_acc
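# Web-page navigation residue from the original article (not code):
# 评论列表 ("Comments") / 文章目录 ("Table of contents")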