def test(self, dataset):
    """Evaluate ``self.model`` on every sample of ``dataset``.

    Samples are fed through the model one at a time. Only forward passes
    and loss evaluation happen here; no backward pass or optimizer step
    is performed.

    Args:
        dataset: Indexable collection supporting ``len()`` whose items
            unpack to ``(ltree, lsent, rtree, rsent, label)`` and which
            exposes a ``num_classes`` attribute.

    Returns:
        A ``(mean_loss, predictions)`` tuple: the summed per-sample loss
        divided by ``len(dataset)``, and a 1-D tensor of length
        ``len(dataset)`` holding one predicted score per sample.

    Raises:
        ZeroDivisionError: If ``dataset`` is empty.
    """
    self.model.eval()
    num_samples = len(dataset)
    total_loss = 0
    predictions = torch.zeros(num_samples)
    # Class values 1..num_classes. The explicit float cast is a no-op on
    # old PyTorch (where arange already yields floats) and is required on
    # PyTorch >= 0.4, where integer arguments give an integer tensor that
    # torch.dot cannot combine with the float model output.
    indices = torch.arange(1, dataset.num_classes + 1).float()
    for idx in tqdm(range(num_samples), desc='Testing epoch ' + str(self.epoch)):
        ltree, lsent, rtree, rsent, label = dataset[idx]
        # NOTE(review): ``volatile=True`` is the pre-0.4 way of disabling
        # graph construction; newer PyTorch ignores it with a warning.
        # Kept as-is to match the legacy API used by this file.
        linput = Var(lsent, volatile=True)
        rinput = Var(rsent, volatile=True)
        target = Var(map_label_to_target(label, dataset.num_classes), volatile=True)
        if self.args.cuda:
            linput, rinput = linput.cuda(), rinput.cuda()
            target = target.cuda()
        output = self.model(ltree, linput, rtree, rinput)
        loss = self.criterion(output, target)
        # Read the first loss element as a Python float. ``view(-1)`` makes
        # this work for both a 1-element loss (old PyTorch) and a 0-dim
        # loss (PyTorch >= 0.4), where ``loss.data[0]`` would raise.
        total_loss += float(loss.data.view(-1)[0])
        output = output.data.squeeze().cpu()
        # Weighted sum of class values by exp(output); presumably the model
        # emits log-probabilities, making this an expected class value --
        # TODO confirm against the model definition.
        predictions[idx] = torch.dot(indices, torch.exp(output))
    return total_loss / num_samples, predictions
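# 评论列表 -- "Comment list": leftover web-page navigation text, not part of the code.
# 文章目录 -- "Table of contents": leftover web-page navigation text, not part of the code.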