modellib.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:rec-attend-public 作者: renmengye 项目源码 文件源码
def f_match_loss(y_out, y_gt, match, timespan, loss_fn, model=None):
  """Computes a matching-weighted loss between outputs and ground truth.

  Each output slot is compared against every ground-truth slot with
  `loss_fn`. The per-pair losses are then weighted by `match`, summed, and
  normalized by the per-example match count, the batch size and the number
  of elements per slot.

  Args:
    y_out: [B, N, H, W] or [B, N, D], outputs.
    y_gt: [B, N, H, W] or [B, N, D], ground truth.
    match: [B, N, N], matching weights. match[b, i, j] weights the loss
      between output slot i and ground-truth slot j.
    timespan: N, a Python int (used to split the tensors along axis 1).
    loss_fn: Callable invoked as loss_fn(y_out_slice, y_gt) where
      y_out_slice is one [B, 1, ...] slice of y_out. Its result is summed
      over axes 2 and up, so it should return an element-wise loss that
      broadcasts to the shape of y_gt.
    model: Not used by this function.

  Returns:
    Scalar tensor: the sum over the batch of
    (matched loss / max(match count, 1)), divided by B and by the product
    of the dimensions of y_out from axis 2 onwards.
  """
  # N * [B, 1, H, W]
  y_out_list = tf.split(1, timespan, y_out)
  # N * [B, 1, N]
  match_list = tf.split(1, timespan, match)

  shape = tf.shape(y_out)
  num_ex = tf.to_float(shape[0])
  # Number of elements per slot (H * W or D); already float after the cast.
  num_dim = tf.reduce_prod(tf.to_float(shape[2:]))
  # Axes 2 .. rank-1, computed dynamically so both the [B, N, H, W] and the
  # [B, N, D] layouts work. Built once because it does not depend on the slot.
  red_idx = tf.range(2, tf.size(shape))

  # [B, N, M] => [B, N]
  match_sum = tf.reduce_sum(match, reduction_indices=[2])
  # [B, N] => [B]
  match_count = tf.reduce_sum(match_sum, reduction_indices=[1])
  # Avoid dividing by zero for examples that have no match at all.
  match_count = tf.maximum(match_count, 1)

  err_list = []
  for y_out_ii, match_ii in zip(y_out_list, match_list):
    # loss_fn([B, 1, H, W], [B, N, H, W]) => [B, N, H, W] => [B, N]
    pair_loss = tf.reduce_sum(loss_fn(y_out_ii, y_gt), red_idx)
    # [B, N] * [B, N] => [B]
    matched_loss = tf.reduce_sum(
        pair_loss * tf.reshape(match_ii, [-1, timespan]), [1])
    # [B] => [B, 1]
    err_list.append(tf.expand_dims(matched_loss, 1))

  # N * [B, 1] => [B, N] => [B]
  err_total = tf.reduce_sum(tf.concat(1, err_list), reduction_indices=[1])

  return tf.reduce_sum(err_total / match_count) / num_ex / num_dim
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号