def _check_predictions(self, predictions, ground_truth):
    """Count the examples whose predicted sequence matches the ground truth exactly.

    An example (row) is counted as correct only when every position along the
    last axis agrees; a single differing element makes the whole row wrong.

    :param predictions: predicted sequences, array-like of shape
        [batch_size, max_question_length]
    :param ground_truth: reference sequences, array-like with the same shape
        as ``predictions``
    :return: number of rows that match exactly, as a Python ``int``
    :raises ValueError: if ``predictions`` and ``ground_truth`` do not have
        the same shape
    """
    p = np.asarray(predictions)
    g = np.asarray(ground_truth)
    # Require identical shapes: broadcasting would otherwise silently compare,
    # e.g., every prediction row against one single reference row.
    if p.shape != g.shape:
        raise ValueError(
            f"predictions shape {p.shape} does not match "
            f"ground_truth shape {g.shape}"
        )
    # An empty 1-D batch (e.g. `[]`) has no rows; reducing it over axis -1
    # would yield a scalar True and be miscounted as one correct example.
    if p.ndim == 1 and p.size == 0:
        return 0
    # Element-wise equality rather than sum(|p - g|) == 0: it works for bool
    # and string dtypes, cannot overflow on extreme integers, and treats equal
    # infinities as matching (inf - inf would be NaN).
    row_matches = np.all(p == g, axis=-1)
    return int(np.count_nonzero(row_matches))
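# NOTE: two lines of web-page navigation residue were captured here when this
# snippet was scraped: "评论列表" ("comment list") and "文章目录" ("table of
# contents"). They are not code and are kept only as this comment.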