qa_evaluator.py source code

python

Project: MP-CNN-Variants  Author: tuzhucheng
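import torch.nn.functional as F

# get_map_mrr is a helper from the MP-CNN-Variants project; its import is not shown in this snippet.


def get_scores(self):
    # Switch the model to evaluation mode (turns off dropout and similar training-only behavior).
    self.model.eval()
    test_cross_entropy_loss = 0
    qids = []
    true_labels = []
    predictions = []

    for batch in self.data_loader:
        qids.extend(batch.id.data.cpu().numpy())
        output = self.model(batch.sentence_1, batch.sentence_2, batch.ext_feats)
        # Sum the per-example losses here; the total is averaged over the dataset after the loop.
        # reduction='sum' and .item() replace the deprecated size_average=False and .data[0].
        test_cross_entropy_loss += F.cross_entropy(output, batch.label, reduction='sum').item()

        true_labels.extend(batch.label.data.cpu().numpy())
        # Assuming the model outputs log-probabilities, exp() recovers probabilities;
        # column 1 is taken as the score of the positive class.
        predictions.extend(output.data.exp()[:, 1].cpu().numpy())

        del output

    # Round each question id to one decimal place.
    qids = list(map(lambda n: int(round(n * 10, 0)) / 10, qids))

    mean_average_precision, mean_reciprocal_rank = get_map_mrr(qids, predictions, true_labels, self.data_loader.device)
    # batch still refers to the last batch of the loop; its dataset gives the total example count.
    test_cross_entropy_loss /= len(batch.dataset.examples)

    return [test_cross_entropy_loss, mean_average_precision, mean_reciprocal_rank], ['cross entropy loss', 'map', 'mrr']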
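
The snippet calls get_map_mrr without showing its body, so the following is only a rough, self-contained sketch of how MAP and MRR could be computed from the three lists that get_scores collects. The function name map_mrr_sketch is hypothetical and is not the project's implementation. It takes no device argument, and skipping questions that have no relevant candidate is an assumption made here.

```python
from collections import defaultdict


def map_mrr_sketch(qids, predictions, labels):
    """Hypothetical stand-in: MAP and MRR over candidates grouped by question id."""
    # Group (score, label) pairs by question id.
    by_question = defaultdict(list)
    for qid, score, label in zip(qids, predictions, labels):
        by_question[qid].append((score, label))

    average_precisions = []
    reciprocal_ranks = []
    for candidates in by_question.values():
        # Rank candidates from highest to lowest predicted score.
        ranked = sorted(candidates, key=lambda pair: pair[0], reverse=True)
        hits = 0
        precisions = []
        first_hit_rank = None
        for rank, (_, label) in enumerate(ranked, start=1):
            if label == 1:
                hits += 1
                precisions.append(hits / rank)  # precision at this relevant rank
                if first_hit_rank is None:
                    first_hit_rank = rank
        # Assumption: questions with no relevant candidate are left out of both averages.
        if hits == 0:
            continue
        average_precisions.append(sum(precisions) / hits)
        reciprocal_ranks.append(1.0 / first_hit_rank)

    if not average_precisions:
        return 0.0, 0.0
    n = len(average_precisions)
    return sum(average_precisions) / n, sum(reciprocal_ranks) / n


# Two questions (ids 1.1 and 2.1) with made-up scores and labels.
example_map, example_mrr = map_mrr_sketch(
    [1.1, 1.1, 1.1, 2.1, 2.1],
    [0.9, 0.2, 0.8, 0.3, 0.7],
    [0, 1, 1, 1, 0],
)
print(example_map, example_mrr)  # MAP = 13/24 (about 0.5417), MRR = 0.5
```

In the example, question 1.1 has relevant candidates at ranks 2 and 3 (average precision 7/12) and question 2.1 has one at rank 2 (average precision 1/2), which gives the MAP of 13/24. The first relevant candidate sits at rank 2 for both questions, so the MRR is 0.5.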