trainer.py source code

python

Project: R-net  Author: matthew-z
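import torch
from torch.autograd import Variable


def _forward(self, batch):
    """Run one batch through the model; return (loss, exact-match rate)."""
    _, questions, passages, answers, _ = batch
    batch_num = questions.tensor.size(0)

    # Project-specific helper on the batch objects (presumably wraps
    # their tensors for autograd).
    questions.variable()
    passages.variable()

    begin_, end_ = self.model(questions, passages)  # batch x seq
    assert begin_.size(0) == batch_num

    answers = Variable(answers)
    if torch.cuda.is_available():
        answers = answers.cuda()
    # Column 0 holds the gold start index, column 1 the gold end index.
    begin, end = answers[:, 0], answers[:, 1]
    loss = self.loss_fn(begin_, begin) + self.loss_fn(end_, end)

    # Predicted span boundaries: argmax over the sequence dimension.
    _, pred_begin = torch.max(begin_, 1)
    _, pred_end = torch.max(end_, 1)

    # An example counts as an exact match only if both boundaries are correct.
    exact_correct_num = torch.sum(
        (pred_begin == begin) * (pred_end == end))
    # .item() requires PyTorch >= 0.4; the original `.data[0]` fails on
    # 0-dim tensors in those versions.
    em = exact_correct_num.item() / batch_num

    return loss, em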
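
The exact-match (EM) step at the end of `_forward` can be checked by hand on a tiny batch. The following is a minimal sketch in plain Python, without PyTorch, assuming the same convention as above: the predicted start and end are the argmax of their score vectors, and an example counts only when both match the gold pair. The names `argmax` and `exact_match` and the toy scores are hypothetical and do not come from the R-net repository.

```python
def argmax(scores):
    """Return the index of the largest score in a list."""
    return max(range(len(scores)), key=lambda i: scores[i])


def exact_match(begin_scores, end_scores, answers):
    """Fraction of examples whose predicted (begin, end) equals the gold pair.

    begin_scores, end_scores: one list of per-position scores per example.
    answers: one (begin, end) pair of gold indices per example.
    """
    correct = 0
    for b_scores, e_scores, (gold_b, gold_e) in zip(
            begin_scores, end_scores, answers):
        # Both boundaries must be right for the example to count.
        if argmax(b_scores) == gold_b and argmax(e_scores) == gold_e:
            correct += 1
    return correct / len(answers)


# Toy batch: 3 examples over a 4-token passage (made-up numbers).
begin_scores = [[0.1, 2.0, 0.3, 0.0],
                [1.5, 0.2, 0.1, 0.0],
                [0.0, 0.1, 0.2, 3.0]]
end_scores = [[0.0, 0.1, 2.5, 0.2],
              [0.3, 1.9, 0.0, 0.1],
              [0.0, 0.2, 0.1, 2.2]]
answers = [(1, 2), (0, 2), (3, 3)]

em = exact_match(begin_scores, end_scores, answers)
print(em)
```

In this toy batch the second example has the right start but the wrong end, so it does not count: the metric gives no partial credit for a single correct boundary.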