qa_utils.py source code

python

Project: knowledgeflow  Author: 3rduncle
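import numpy as np


def label_ranking_average_precision_score2(self, model, batch_size=50):
    """Compute mean average precision (MAP) over all questions in the evaluation set.

    Each question groups several candidate answers. The model scores every
    (question, answer) pair, and candidates are ranked by that score.
    """
    def label_ranking_average_precision_score(label, score):
        # Per-question average precision (AP), plus the precision at the
        # first relevant candidate (i.e. its reciprocal rank).
        assert len(label) == len(score)
        # Rank candidates by descending model score.
        data = sorted(zip(label, score), key=lambda x: x[1], reverse=True)
        count = 0.0
        values = []
        for i, (relevant, _) in enumerate(data):
            if relevant:
                count += 1
                # Precision at rank i + 1, recorded only at relevant positions.
                values.append(count / (i + 1))
        assert values, 'question has no relevant candidate'
        return sum(values) / count, values[0]

    p = model.predict(
        {'q_input': self.xq_np, 'a_input': self.xa_np},
        batch_size=batch_size
    )
    ap_record = []
    for question, entry in self.questions.items():
        idx = np.array(entry['idx'])
        # AP is undefined without a positive answer, so skip such questions.
        if self.y_np[idx].max() == 0:
            continue
        score = p[idx].reshape(idx.shape).tolist()
        # Avoid shadowing the built-in `map`.
        ap, _ = label_ranking_average_precision_score(entry['label'], score)
        ap_record.append(ap)
        self.saveResult(question, ap, score)
    # MAP: the mean of the per-question AP values.
    mean_ap = np.array(ap_record).mean()
    self.saveResult('__TOTAL_MAP__', mean_ap)
    return mean_ap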
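The ranking metric in the method above can be checked without a trained model. Below is a minimal, stdlib-only sketch that re-implements the same computation. It is not the knowledgeflow project's code. The function names `average_precision` and `mean_average_precision`, and the `{"label": [...], "score": [...]}` layout per question, are hypothetical. They only mirror the `entry['label']` field and the predicted scores used above. The second value returned per question corresponds to `values[0]` in the original: the precision at the first relevant candidate, which equals the reciprocal rank. As in the original, questions with no relevant candidate are skipped when averaging.

```python
from typing import Dict, List, Sequence, Tuple


def average_precision(labels: Sequence[int], scores: Sequence[float]) -> Tuple[float, float]:
    """Return (average precision, reciprocal rank) for one question.

    Candidates are ranked by descending score. Ties keep their input order,
    because Python's sort is stable even with reverse=True.
    """
    if len(labels) != len(scores):
        raise ValueError("labels and scores must have the same length")
    ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)
    hits = 0
    precisions = []
    for rank, (relevant, _) in enumerate(ranked, start=1):
        if relevant:
            hits += 1
            precisions.append(hits / rank)  # precision@rank at each relevant hit
    if not precisions:
        raise ValueError("question has no relevant candidate")
    return sum(precisions) / hits, precisions[0]


def mean_average_precision(questions: Dict[str, Dict[str, List]]) -> float:
    """Average AP over questions that have at least one relevant candidate.

    Hypothetical layout: questions[q] = {"label": [...], "score": [...]}.
    """
    aps = [
        average_precision(entry["label"], entry["score"])[0]
        for entry in questions.values()
        if any(entry["label"])  # mirror the original: skip all-negative questions
    ]
    if not aps:
        raise ValueError("no question has a relevant candidate")
    return sum(aps) / len(aps)


if __name__ == "__main__":
    demo = {
        "q1": {"label": [0, 1, 0, 1], "score": [0.9, 0.8, 0.3, 0.1]},  # AP 0.5
        "q2": {"label": [1, 0, 0], "score": [0.9, 0.7, 0.1]},  # AP 1.0
        "q3": {"label": [0, 0], "score": [0.5, 0.4]},  # no positive, skipped
    }
    print(mean_average_precision(demo))  # prints 0.75
```

One design consequence is worth noting. Because ties keep input order, a relevant candidate tied with an irrelevant one that appears earlier in the input is ranked lower, which reduces AP. The original method inherits the same behavior from `sorted`.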