acc_listener.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:tfplus 作者: renmengye 项目源码 文件源码
def listen(self, results):
        score_out = results['score_out']
        y_gt = results['y_gt']
        sort_idx = np.argsort(score_out, axis=-1)
        idx_gt = np.argmax(y_gt, axis=-1)
        correct = 0
        count = 0
        for kk, ii in enumerate(idx_gt):
            sort_idx_ = sort_idx[kk][::-1]
            for jj in sort_idx_[:self.top_k]:
                if ii == jj:
                    correct += 1
                    break
            count += 1
        # self.log.info('Correct {}/{}'.format(correct, count))
        self.correct += correct
        self.count += count
        self.step = int(results['step'])
        # self.log.info('Step {}'.format(self.step))
        pass
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号