def wasserstein_disagreement_map(prediction, ground_truth, M):
    """
    Compute the pixel-wise Wasserstein distance between the predicted class
    probabilities and the labels (ground_truth), with respect to the distance
    matrix ``M`` on the label space.

    For every pixel ``p`` the value is::

        W(p) = sum_{i, j} M[i, j] * prediction[p, i] * ground_truth[p, j]

    It is evaluated as ``sum_j (prediction . M)[p, j] * ground_truth[p, j]``.
    This is one contraction over the class axis instead of ``n_classes ** 2``
    separate per-pair tensors. It also does not require the number of classes
    to be statically known.

    :param prediction: class probabilities (the logits after softmax), with
        the class axis as the LAST axis. Typically flattened to
        ``(n_pixels, n_classes)``, but any leading shape is accepted.
    :param ground_truth: per-class label encoding (presumably one-hot) with
        the same shape as ``prediction``.
    :param M: ``(n_classes, n_classes)`` distance matrix on the label space
        (array-like or tensor). ``M[i, j]`` weights the pair
        (predicted class ``i``, true class ``j``).
    :return: the pixel-wise distance map (wass_dis_map), as ``float64``, with
        the shape of ``prediction`` minus its last (class) axis.
    """
    # Work in float64 throughout. M is cast as well so that a float32 distance
    # matrix does not cause a dtype mismatch with the float64 probabilities.
    prediction = tf.cast(prediction, dtype=tf.float64)
    ground_truth = tf.cast(ground_truth, dtype=tf.float64)
    M = tf.cast(M, dtype=tf.float64)

    # Contract the last (class) axis of prediction with the first axis of M:
    #     weighted[..., j] = sum_i prediction[..., i] * M[i, j]
    weighted = tf.tensordot(prediction, M, axes=1)

    # Correlate with the labels and sum over the true class j. This yields the
    # weighted sum of all pairwise correlations (pred_ci x labels_cj).
    wass_dis_map = tf.reduce_sum(weighted * ground_truth, axis=-1)
    return wass_dis_map
评论列表
文章目录