def _batch_MAP_MRR(self,
s_label, # [batch_size, sent_num]
s_preds, # [batch_size, sent_num]
mask): # [batch_size, sent_num]
""" Calcualte the Mean Average Precision and Mean Reciprocal Rank
"""
average_precisions = []
reciprocal_ranks = []
for i in xrange(s_label.shape[0]): # For each question in the batch
# Only keep those not padded
label = np.take(s_label[i], np.where(mask[i] == 1)[0])
preds = np.take(s_preds[i], np.where(mask[i] == 1)[0])
assert(label.shape == preds.shape)
# MAP only makes sense for positive bags
try:
assert(np.max(label) > 0)
except AssertionError as e:
print(s_label)
raise e
# TODO: is this correct???
ap = label_ranking_average_precision_score([label], # true binary label
[preds]) # target scores
rr = label_ranking_reciprocal_rank(label, preds)
try: assert(not np.isnan(ap) and not np.isnan(rr))
except: pdb.set_trace()
average_precisions.append(ap)
reciprocal_ranks.append(rr)
return average_precisions, reciprocal_ranks
评论列表
文章目录