def f_match_loss(y_out, y_gt, match, timespan, loss_fn, model=None):
    """Compute a matching-weighted loss between outputs and ground truth.

    Each of the ``timespan`` output slots is compared against all
    ground-truth slots through ``loss_fn``. Every pairwise loss is summed over
    the trailing feature axes, weighted by the corresponding entry of
    ``match`` and accumulated per example. The per-example total is divided by
    that example's match count (clamped to at least 1); the result is summed
    over the batch and divided by the batch size and by the number of feature
    elements per slot.

    Args:
        y_out: Outputs, [B, N, H, W] or [B, N, D].
        y_gt: Ground truth, [B, N, H, W] or [B, N, D].
        match: Matching weights, [B, N, N]. Slice ``match[:, i, :]`` weights
            output slot ``i`` against each ground-truth slot.
        timespan: N, the number of slots along axis 1 (a Python int; it is
            used as a split count and as a reshape dimension).
        loss_fn: Callable ``loss_fn(out, gt)`` returning an element-wise loss.
            It is called with a [B, 1, ...] output slice and the full
            [B, N, ...] ground truth, so it is expected to broadcast.
        model: Unused by this function; kept so existing callers keep working.

    Returns:
        A scalar tensor holding the normalized matched loss.
    """
    # NOTE: the calls below use the legacy (pre-1.0) TensorFlow argument order,
    # tf.split(split_dim, num_split, value) and tf.concat(concat_dim, values),
    # exactly as the rest of this code does.
    # N * [B, 1, H, W]
    y_out_list = tf.split(1, timespan, y_out)
    # N * [B, 1, N]
    match_list = tf.split(1, timespan, match)

    shape = tf.shape(y_out)
    num_ex = tf.to_float(shape[0])
    # Number of feature elements per slot (product of all axes after the
    # first two), as a float for the final normalization.
    num_dim = tf.reduce_prod(tf.to_float(shape[2:]))
    # Trailing feature axes (2 .. rank-1) to sum over. This does not depend on
    # the slot index, so it is built once instead of once per iteration.
    red_idx = tf.range(2, tf.size(shape))

    # [B, N, M] => [B, N]
    match_sum = tf.reduce_sum(match, reduction_indices=[2])
    # [B, N] => [B]
    match_count = tf.reduce_sum(match_sum, reduction_indices=[1])
    # Clamp so that examples without any match do not divide by zero.
    match_count = tf.maximum(match_count, 1)

    err_list = []
    for ii in range(timespan):
        # [B, 1, H, W] vs. [B, N, H, W] => [B, N, H, W] => [B, N]
        pair_loss = tf.reduce_sum(loss_fn(y_out_list[ii], y_gt), red_idx)
        # [B, 1, N] => [B, N]
        weights = tf.reshape(match_list[ii], [-1, timespan])
        # [B, N] * [B, N] => [B] => [B, 1]
        err_list.append(
            tf.expand_dims(tf.reduce_sum(pair_loss * weights, [1]), 1))

    # N * [B, 1] => [B, N] => [B]
    err_total = tf.reduce_sum(tf.concat(1, err_list), reduction_indices=[1])
    return tf.reduce_sum(err_total / match_count) / num_ex / num_dim
# Comment list
# Table of contents