def evaluate_groups(true_groups, predicted):
    """Score predicted group assignments against the ground truth, per sample.

    For every sample in the batch, each position is assigned to the group
    with the largest value along axis 1 of ``predicted``. Positions whose
    ground-truth label equals 0 are excluded, and the adjusted mutual
    information (AMI) between the remaining true labels and the assigned
    groups is computed. The confidence of a sample is the mean, over the
    same positions, of the largest value along axis 1 of ``predicted``.

    :param true_groups: ground-truth labels, expected shape (B, 1, W, H, 1)
    :param predicted: per-group values, expected shape (B, K, W, H, 1)
    :return: (scores, confidences), two lists with one entry per sample
    """
    # Both inputs must be 5-dimensional; the failure message reports the
    # shape of the ground truth.
    assert true_groups.ndim == predicted.ndim == 5, true_groups.shape

    batch_size, num_groups = predicted.shape[:2]

    # Collapse all trailing axes so every sample becomes a flat vector.
    flat_truth = true_groups.reshape(batch_size, -1)
    flat_pred = predicted.reshape(batch_size, num_groups, -1)

    # Winning group index and its value for every position.
    assigned_groups = flat_pred.argmax(1)
    peak_values = flat_pred.max(1)

    scores = []
    confidences = []
    for labels, assigned, peaks in zip(flat_truth, assigned_groups, peak_values):
        # Only positions with a non-zero ground-truth label are evaluated.
        keep = np.nonzero(labels != 0.0)[0]
        scores.append(adjusted_mutual_info_score(labels[keep], assigned[keep]))
        confidences.append(np.mean(peaks[keep]))
    return scores, confidences
# [web-page navigation residue] Comment list
# [web-page navigation residue] Article table of contents