def _forward(self, batch):
    """Run the model on one batch and return its loss and exact-match rate.

    Args:
        batch: A 5-item sequence. Only items 2-4 are used, as
            ``questions``, ``passages`` and ``answers``. ``questions`` and
            ``passages`` must provide a ``variable()`` method, and
            ``questions.tensor`` supplies the batch size through its first
            dimension. ``answers`` is indexed as ``[:, 0]`` (begin target)
            and ``[:, 1]`` (end target).

    Returns:
        tuple: ``(loss, em)``. ``loss`` is ``self.loss_fn`` on the begin
        scores plus ``self.loss_fn`` on the end scores. ``em`` is a float:
        the fraction of examples whose arg-max begin AND arg-max end both
        equal their targets.

    Raises:
        AssertionError: If the model output's first dimension differs from
            the batch size read from ``questions.tensor``.
    """
    _, questions, passages, answers, _ = batch
    batch_num = questions.tensor.size(0)
    # NOTE(review): presumably these convert the wrapped tensors in place
    # (return values are discarded) -- confirm against the wrapper class.
    questions.variable()
    passages.variable()

    begin_, end_ = self.model(questions, passages)  # batch x seq
    assert begin_.size(0) == batch_num

    answers = Variable(answers)
    # Put the targets wherever the model output lives. Keying this off
    # torch.cuda.is_available() breaks when CUDA is present but the model
    # runs on the CPU (loss_fn would receive tensors on different devices).
    if begin_.is_cuda:
        answers = answers.cuda(begin_.get_device())
    begin, end = answers[:, 0], answers[:, 1]
    loss = self.loss_fn(begin_, begin) + self.loss_fn(end_, end)

    _, pred_begin = torch.max(begin_, 1)
    _, pred_end = torch.max(end_, 1)
    both_correct = (pred_begin == begin) * (pred_end == end)
    # Sum as long rather than in the mask's own (byte/bool) dtype so the
    # count cannot wrap for large batches.
    exact_correct_num = torch.sum(both_correct.long())
    # Recent PyTorch returns a 0-dim tensor here, which cannot be indexed
    # with [0]; fall back to indexing only when .item() is unavailable.
    if hasattr(exact_correct_num, "item"):
        exact_count = exact_correct_num.item()
    else:
        exact_count = exact_correct_num.data[0]
    # float() keeps the ratio exact even under integer-division semantics.
    em = float(exact_count) / batch_num
    return loss, em
评论列表
文章目录