def eval(self):
    """Run the model over the dev set and score its span predictions.

    The model is put into evaluation mode for the pass. It is put back into
    training mode afterwards, even if an exception is raised part-way
    through, so a failed evaluation cannot leave the model in eval mode for
    subsequent training.

    For each dev batch, the begin and end score tensors returned by the model
    are each reduced with an argmax over dimension 1. The predicted answer
    for example ``i`` is ``passage_tokenized[i][begin:end + 1]`` joined with
    single spaces.

    Returns:
        The result of ``evaluate(self.dev_dataset, pred_result)``, where
        ``pred_result`` maps each question id to its predicted answer string.
    """
    self.model.eval()
    pred_result = {}
    try:
        for batch in self.dataloader_dev:
            question_ids, questions, passages, passage_tokenized = batch
            # NOTE(review): the return values of .variable() are discarded,
            # so these calls presumably convert the batch objects in place
            # (legacy PyTorch `volatile` inference mode) — confirm against
            # the batch class.
            questions.variable(volatile=True)
            passages.variable(volatile=True)
            begin_, end_ = self.model(questions, passages)  # batch x seq

            # Begin and end are chosen independently; nothing here forces
            # end >= begin, in which case the slice below yields an empty
            # answer.
            _, pred_begin = torch.max(begin_, 1)
            _, pred_end = torch.max(end_, 1)
            pred = torch.stack([pred_begin, pred_end], dim=1)

            for i, (begin, end) in enumerate(pred.cpu().data.numpy()):
                ans = passage_tokenized[i][begin:end + 1]
                pred_result[question_ids[i]] = " ".join(ans)
    finally:
        # Restore training mode even if the evaluation pass fails.
        self.model.train()
    return evaluate(self.dev_dataset, pred_result)
评论列表
文章目录