def test(self, dataset):
    """Evaluate the model on every sample of ``dataset``.

    ``self.model`` and ``self.embedding_model`` are switched to eval mode,
    then each sample is embedded and passed through the model one at a
    time while loss, accuracy and outputs are collected.

    Args:
        dataset: Indexable collection supporting ``len()`` whose items
            unpack as ``(tree, sent, label)``; ``label`` must be
            convertible with ``int()``.

    Returns:
        A 4-tuple ``(mean_loss, accuracies, outputs, output_trees)``:
        the summed criterion loss divided by ``len(dataset)``, a tensor
        with the per-sample accuracy value returned by the model, a list
        with each sample's ``output_softmax`` as a NumPy array, and the
        list of trees returned by the model.

    Raises:
        ZeroDivisionError: If ``dataset`` is empty.
    """
    self.model.eval()
    self.embedding_model.eval()

    total_loss = 0
    num_samples = len(dataset)
    accuracies = torch.zeros(num_samples)
    output_trees = []
    outputs = []

    progress = tqdm(range(num_samples), desc='Testing epoch ' + str(self.epoch))
    for idx in progress:
        tree, sent, label = dataset[idx]

        # volatile=True is the pre-0.4 PyTorch way of marking Variables as
        # inference-only; kept for compatibility with the rest of the file.
        inputs = Var(sent, volatile=True)
        target = Var(torch.LongTensor([int(label)]), volatile=True)
        if self.args.cuda:
            inputs = inputs.cuda()
            target = target.cuda()

        # Insert a new axis at position 1 of the embedded sentence
        # (presumably a batch axis of size 1 -- TODO confirm against model).
        emb = torch.unsqueeze(self.embedding_model(inputs), 1)
        output, _, acc, output_tree = self.model(output_tree_input(tree), emb) if False else self.model(tree, emb)

        err = self.criterion(output, target)
        total_loss += err.data[0]
        accuracies[idx] = acc
        output_trees.append(output_tree)

        # Tensor.numpy() only works on CPU tensors, so bring the softmax
        # back from the GPU first; .cpu() leaves CPU tensors unchanged.
        outputs.append(output_tree.output_softmax.data.cpu().numpy())

    return total_loss / num_samples, accuracies, outputs, output_trees
评论列表
文章目录