losses.py source code

Project: tefla  Author: openAGI
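import tensorflow as tf


def correlation_loss(source_samples, target_samples, weight, name='corr_loss'):
    """Adds a similarity loss term: the correlation between two representations.

    Both inputs are centered per feature and L2-normalized per sample. The loss
    is the mean squared difference between their [num_features, num_features]
    correlation matrices, scaled by `weight`.

    Args:
        source_samples: a tensor of shape [num_samples, num_features].
        target_samples: a tensor of shape [num_samples, num_features].
        weight: a scalar weight for the loss.
        name: optional name scope for the ops created by this function.

    Returns:
        a scalar tensor representing the correlation loss value.
    """
    with tf.name_scope(name):
        # Center each feature (column) across the batch.
        source_samples = source_samples - tf.reduce_mean(source_samples, 0)
        target_samples = target_samples - tf.reduce_mean(target_samples, 0)
        # Scale each sample (row) to unit L2 norm.
        source_samples = tf.nn.l2_normalize(source_samples, 1)
        target_samples = tf.nn.l2_normalize(target_samples, 1)
        # Feature-by-feature correlation matrices, X^T X.
        source_cov = tf.matmul(source_samples, source_samples, transpose_a=True)
        target_cov = tf.matmul(target_samples, target_samples, transpose_a=True)
        corr_loss = tf.reduce_mean(
            tf.square(source_cov - target_cov)) * weight

        # Fail loudly on NaN/Inf. In graph mode an assertion only runs if the
        # returned tensor depends on it, so pass the result through
        # tf.identity under the control dependency (an unused tf.no_op would
        # never trigger it).
        assert_op = tf.Assert(tf.math.is_finite(corr_loss), [corr_loss])
        with tf.control_dependencies([assert_op]):
            corr_loss = tf.identity(corr_loss, name='correlation_loss')

    return corr_loss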
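To check the arithmetic without TensorFlow, here is a minimal pure-Python sketch of the same computation, assuming small list-of-lists inputs. The names correlation_loss_reference, _center_columns, _l2_normalize_rows and _gram are hypothetical helpers for illustration, not part of tefla. It mirrors each step of the function above: center every feature, L2-normalize every sample with the same 1e-12 floor on the squared norm that tf.nn.l2_normalize uses by default, form X^T X for each input, then average the squared differences and multiply by the weight. In place of tf.Assert, it raises ValueError when the loss is not finite.

```python
import math
from typing import List

Matrix = List[List[float]]


def _center_columns(x: Matrix) -> Matrix:
    """Subtract each feature's (column's) mean, like reduce_mean(x, 0)."""
    n = len(x)
    means = [sum(row[j] for row in x) / n for j in range(len(x[0]))]
    return [[v - m for v, m in zip(row, means)] for row in x]


def _l2_normalize_rows(x: Matrix, eps: float = 1e-12) -> Matrix:
    """Scale each sample (row) to unit L2 norm, like l2_normalize(x, 1)."""
    out = []
    for row in x:
        sq = sum(v * v for v in row)
        inv = 1.0 / math.sqrt(max(sq, eps))  # eps guards all-zero rows
        out.append([v * inv for v in row])
    return out


def _gram(x: Matrix) -> Matrix:
    """Return x^T x, a [num_features, num_features] matrix."""
    d = len(x[0])
    return [[sum(row[i] * row[j] for row in x) for j in range(d)]
            for i in range(d)]


def correlation_loss_reference(source: Matrix, target: Matrix,
                               weight: float) -> float:
    """Hypothetical reference version of correlation_loss (sketch only)."""
    s = _gram(_l2_normalize_rows(_center_columns(source)))
    t = _gram(_l2_normalize_rows(_center_columns(target)))
    d = len(s)
    # Mean of squared element-wise differences, like reduce_mean(square(...)).
    mse = sum((s[i][j] - t[i][j]) ** 2
              for i in range(d) for j in range(d)) / (d * d)
    loss = mse * weight
    if not math.isfinite(loss):
        raise ValueError(f"correlation loss is not finite: {loss}")
    return loss


src = [[1.0, 0.0], [-1.0, 0.0]]
tgt = [[0.0, 1.0], [0.0, -1.0]]
print(correlation_loss_reference(src, tgt, 0.5))  # prints 1.0
```

Because X^T X is num_features by num_features whatever the batch size, the two inputs only have to agree on num_features; neither this sketch nor the TensorFlow version requires equal num_samples.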