modellib.py source code

Project: rec-attend-public  Author: renmengye

```python
import logging

import tensorflow as tf

# Assumption: a standard-library logger stands in for the project's own `log`.
log = logging.getLogger(__name__)

# Custom op library, loaded lazily on the first call to f_segm_match.
hungarian_module = None


def f_segm_match(iou, s_gt):
  """Matching between segmentation output and groundtruth.

  Args:
    iou: [B, N, M], pairwise IoU between the N output and the M groundtruth
      segmentations. Both masks below are built from s_gt, so N == M == T.
    s_gt: [B, T], groundtruth score sequence, multiplied in as a mask.

  Returns:
    match: [B, N, M], output of the Hungarian op with the rows and columns
      where s_gt is zero masked out.
  """
  global hungarian_module
  if hungarian_module is None:
    # The path is relative to the current working directory.
    mod_name = './hungarian.so'
    hungarian_module = tf.load_op_library(mod_name)
    log.info('Loaded library "{}"'.format(mod_name))

  # Mask X, [B, M] => [B, 1, M]
  mask_x = tf.expand_dims(s_gt, axis=1)
  # Mask Y, [B, N] => [B, N, 1]
  mask_y = tf.expand_dims(s_gt, axis=2)
  # Broadcasting scales column j by s_gt[j] and row i by s_gt[i].
  iou_mask = iou * mask_x * mask_y

  # Keep a fixed precision so that the optimal matching can be found
  # within reasonable time.
  eps = 1e-5
  precision = 1e6
  iou_mask = tf.round(iou_mask * precision) / precision
  match_eps = hungarian_module.hungarian(iou_mask + eps)[0]

  # Mask the graph algorithm output.
  match = match_eps * mask_x * mask_y

  return match
```
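For readers without the compiled `hungarian.so` op, the sketch below reproduces the same steps in plain Python for a single batch element: mask the IoU matrix with `s_gt`, round it to a fixed precision, add `eps`, find an assignment, and mask the result again. It assumes the op returns a 0/1 assignment matrix that maximises total weight. It also swaps the Hungarian algorithm for brute-force search over permutations, which reaches the same optimum but only scales to small T. The helper names `apply_mask`, `round_to` and `best_match` are hypothetical.

```python
from itertools import permutations


def apply_mask(mat, s_gt):
    """Multiply entry (i, j) by s_gt[i] * s_gt[j].

    Same effect as `mat * mask_x * mask_y` in f_segm_match for one batch
    element: mask_x scales column j, mask_y scales row i.
    """
    t = len(s_gt)
    return [[mat[i][j] * s_gt[j] * s_gt[i] for j in range(t)] for i in range(t)]


def round_to(x, precision=1e6):
    """Quantise like tf.round(x * precision) / precision (half to even)."""
    return round(x * precision) / precision


def best_match(iou, s_gt, eps=1e-5):
    """Hypothetical stand-in for hungarian_module.hungarian(...)[0].

    Assumes the op returns a 0/1 assignment matrix maximising total weight.
    Exhaustive search over permutations reaches the same optimum as the
    Hungarian algorithm but is only practical for small T.
    """
    t = len(s_gt)
    masked = apply_mask(iou, s_gt)
    weights = [[round_to(v) + eps for v in row] for row in masked]
    best = max(permutations(range(t)),
               key=lambda perm: sum(weights[i][perm[i]] for i in range(t)))
    match = [[0.0] * t for _ in range(t)]
    for i, j in enumerate(best):
        match[i][j] = 1.0
    # Mask the assignment, as the TF code does with match_eps.
    return apply_mask(match, s_gt)


iou = [[0.9, 0.1, 0.0],
       [0.2, 0.8, 0.0],
       [0.0, 0.0, 0.7]]
s_gt = [1, 1, 0]  # third slot has score zero
print(best_match(iou, s_gt))
# → [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]
```

With `s_gt = [1, 1, 0]`, the third row and column are zeroed both before and after matching, so that slot never contributes a match even though its raw IoU is 0.7. For realistic T, a polynomial-time solver such as SciPy's `linear_sum_assignment` with `maximize=True` would replace the permutation search.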