def listen(self, results):
    """Accumulate top-k classification accuracy counts from one batch.

    Reads per-class scores from ``results['score_out']`` and ground-truth
    labels from ``results['y_gt']``. The labels are decoded with ``argmax``
    over the last axis, so presumably one-hot or per-class targets; confirm
    with the caller. The method counts how many samples have their
    ground-truth class among the ``self.top_k`` highest-scoring classes.
    It adds that number to ``self.correct`` and the number of samples to
    ``self.count``. It also sets ``self.step`` from ``results['step']``.

    Args:
        results: Mapping with keys ``'score_out'``, ``'y_gt'`` and ``'step'``.
            ``score_out`` and ``y_gt`` share their last (class) axis. Any
            leading axes are flattened into one sample axis, so a single
            1-D sample is also accepted.
    """
    score_out = np.asarray(results['score_out'])
    y_gt = np.asarray(results['y_gt'])

    # Flatten all leading axes so that every row is exactly one sample.
    score_out = score_out.reshape(-1, score_out.shape[-1])
    y_gt = y_gt.reshape(-1, y_gt.shape[-1])

    idx_gt = np.argmax(y_gt, axis=-1)
    # Ascending argsort reversed per row gives classes from highest to lowest
    # score. This keeps the same tie ordering as a per-row loop would.
    ranked = np.argsort(score_out, axis=-1)[:, ::-1]
    top_k_idx = ranked[:, :self.top_k]
    # A sample is a hit when its ground-truth class appears in its top-k set.
    hits = np.any(top_k_idx == idx_gt[:, None], axis=-1)

    # Convert to Python ints so the running totals keep a plain int type.
    self.correct += int(np.count_nonzero(hits))
    self.count += int(idx_gt.shape[0])
    self.step = int(results['step'])
评论列表
文章目录