qa_utils.py source code

python

Project: knowledgeflow    Author: 3rduncle
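# Module-level imports used below (assumed to sit at the top of qa_utils.py;
# they are not shown in the original excerpt).
import logging

import numpy as np


def label_ranking_average_precision_score(self, predictor, batch_size=50):
    # Method of QaPairsTest: returns the mean average precision (MAP) of
    # `predictor` over the questions in the test set.
    from sklearn.metrics import label_ranking_average_precision_score

    # Score every (question, answer) pair, one batch at a time.
    p = []
    for xq_batch, xa_batch, _ in super(QaPairsTest, self).sampling(batch_size):
        delta = predictor(xq_batch, xa_batch)
        p += delta[0].tolist()
    p = np.array(p)

    # Compute the average precision of each question's candidate ranking:
    # 1. questions with no correct answer are skipped (counted in skip1),
    #    since sklearn would score such a row as 1.0;
    # 2. questions whose candidates are all correct are counted in skip2
    #    but are still scored.
    ap_record = []
    skip1 = 0
    skip2 = 0
    for entry in self.questions.values():
        idx = np.array(entry['idx'])
        if self.y_np[idx].max() == 0:
            skip1 += 1
            continue
        if self.y_np[idx].min() != 0:
            skip2 += 1
            #continue
        score = p[idx].reshape(idx.shape).tolist()
        # One row per call, so the result is this question's average precision.
        ap = label_ranking_average_precision_score(np.array([entry['label']]), np.array([score]))
        ap_record.append(ap)
    logging.info('Skip1 %d Skip2 %d', skip1, skip2)
    return np.array(ap_record).mean()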
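The method calls scikit-learn's label_ranking_average_precision_score on one row (one question) at a time. As we read scikit-learn's documented definition, for n rows of binary labels y and scores f̂ it computes the expression below. With n = 1, the inner term is the average precision of that question's ranked candidates, so the mean the method returns over the scored question set Q is the MAP. Ties are counted against the ranking, because rank includes every candidate scored at least as high.

```latex
\mathrm{LRAP}(y, \hat f) = \frac{1}{n}\sum_{i=1}^{n}\frac{1}{\lVert y_i\rVert_0}
  \sum_{j:\,y_{ij}=1}\frac{|\mathcal{L}_{ij}|}{\mathrm{rank}_{ij}},
\qquad
\mathcal{L}_{ij} = \{k : y_{ik}=1,\ \hat f_{ik}\ge \hat f_{ij}\},
\qquad
\mathrm{rank}_{ij} = \bigl|\{k : \hat f_{ik}\ge \hat f_{ij}\}\bigr|

\mathrm{MAP} = \frac{1}{|Q|}\sum_{q\in Q}\mathrm{LRAP}\bigl(y_q, \hat f_q\bigr)
```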
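The per-question logic can be checked without a trained model. Below is a minimal, dependency-free sketch, not the project's code: `average_precision` reimplements the single-row LRAP definition (including scikit-learn's convention that a row whose labels are all 0 or all 1 scores 1.0, which is why all-negative questions are skipped), and `mean_average_precision` mirrors the method's skip counters. The `questions` layout ({'idx': [...], 'label': [...]} indexing into a flat score list) is a hypothetical stand-in for `self.questions`, and it assumes the labels in `entry['label']` agree with `self.y_np[idx]`.

```python
from typing import Dict, List, Sequence, Tuple


def average_precision(labels: Sequence[int], scores: Sequence[float]) -> float:
    """Single-row LRAP, i.e. the average precision of one question's ranking."""
    relevant = [j for j, y in enumerate(labels) if y]
    # Same convention as sklearn: all-negative or all-positive rows score 1.0.
    if not relevant or len(relevant) == len(labels):
        return 1.0
    total = 0.0
    for j in relevant:
        # rank: candidates scored at least as high as candidate j (ties included)
        rank = sum(1 for s in scores if s >= scores[j])
        # hits: relevant candidates scored at least as high as candidate j
        hits = sum(1 for k in relevant if scores[k] >= scores[j])
        total += hits / rank
    return total / len(relevant)


def mean_average_precision(
    questions: Dict[str, Dict[str, List]], scores: Sequence[float]
) -> Tuple[float, int, int]:
    """MAP over questions, returning (map, skip1, skip2) like the method's log."""
    ap_record = []
    skip1 = skip2 = 0
    for entry in questions.values():
        labels = entry['label']
        if max(labels) == 0:   # no correct answer: skip the question
            skip1 += 1
            continue
        if min(labels) != 0:   # every answer correct: count it, still score it
            skip2 += 1
        q_scores = [scores[i] for i in entry['idx']]
        ap_record.append(average_precision(labels, q_scores))
    mean = sum(ap_record) / len(ap_record) if ap_record else float('nan')
    return mean, skip1, skip2


if __name__ == '__main__':
    # Hypothetical toy data: three questions sharing one flat score list.
    questions = {
        'q1': {'idx': [0, 1, 2, 3], 'label': [1, 0, 1, 0]},
        'q2': {'idx': [4, 5], 'label': [0, 0]},
        'q3': {'idx': [6, 7], 'label': [0, 1]},
    }
    scores = [0.9, 0.8, 0.3, 0.1, 0.5, 0.4, 0.7, 0.2]
    print(mean_average_precision(questions, scores))
```

Here q1 has average precision (1/1 + 2/3) / 2 = 5/6, q2 is skipped, and q3 has 1/2, so the MAP is 2/3. Without the skip, q2 would contribute a 1.0 and inflate the mean.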