trainer.py 文件源码

python
阅读 44 收藏 0 点赞 0 评论 0

项目:R-net 作者: matthew-z 项目源码 文件源码
def eval(self):
        self.model.eval()
        pred_result = {}
        for _, batch in enumerate(self.dataloader_dev):

            question_ids, questions, passages, passage_tokenized = batch
            questions.variable(volatile=True)
            passages.variable(volatile=True)
            begin_, end_ = self.model(questions, passages)  # batch x seq

            _, pred_begin = torch.max(begin_, 1)
            _, pred_end = torch.max(end_, 1)

            pred = torch.stack([pred_begin, pred_end], dim=1)

            for i, (begin, end) in enumerate(pred.cpu().data.numpy()):
                ans = passage_tokenized[i][begin:end + 1]
                qid = question_ids[i]
                pred_result[qid] = " ".join(ans)
        self.model.train()
        return evaluate(self.dev_dataset, pred_result)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号