def f_segm_match(iou, s_gt):
    """Match output segments to groundtruth segments with a Hungarian op.

    Masks the IoU matrix by the groundtruth score sequence on both axes,
    quantizes it, runs the custom ``hungarian`` op on it and masks the
    op's first output again.

    Args:
        iou: IoU scores between output and groundtruth segments. Presumably
            [B, T, T] (rows: output segments, columns: groundtruth
            segments) -- TODO confirm against caller.
        s_gt: [B, T], groundtruth score sequence. It is broadcast as both a
            column mask and a row mask over ``iou``.

    Returns:
        The first output of ``hungarian_module.hungarian`` on the masked
        IoU matrix, multiplied by the same row/column masks. Presumably a
        [B, T, T] assignment matrix -- confirm against the op's definition.
    """
    global hungarian_module
    # Lazily load the custom op library on first use and cache it globally.
    if hungarian_module is None:
        # NOTE(review): this relative path is resolved against the process's
        # current working directory, not this file's directory.
        mod_name = './hungarian.so'
        hungarian_module = tf.load_op_library(mod_name)
        log.info('Loaded library "{}"'.format(mod_name))

    # The axis is passed positionally so this works with both the pre-1.0
    # (`dim`) and post-1.0 (`axis`) signatures of tf.expand_dims.
    # Column mask, [B, M] => [B, 1, M]
    mask_x = tf.expand_dims(s_gt, 1)
    # Row mask, [B, N] => [B, N, 1]
    mask_y = tf.expand_dims(s_gt, 2)
    iou_mask = iou * mask_x * mask_y

    # Keep certain precision so that we can get optimal matching within
    # reasonable time.
    eps = 1e-5
    precision = 1e6
    iou_mask = tf.round(iou_mask * precision) / precision
    # eps is added to every entry (including masked-out ones) before the
    # matching; the masks below zero the masked-out entries again.
    match_eps = hungarian_module.hungarian(iou_mask + eps)[0]

    # Mask the graph algorithm output.
    match = match_eps * mask_x * mask_y
    return match
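# NOTE: stray web-page residue from where this code was copied:
# "评论列表" (comment list), "文章目录" (table of contents).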