losses.py source code

python
Project: tefla  Author: openAGI
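from functools import partial

import tensorflow as tf

# `util` and `maximum_mean_discrepancy` are defined elsewhere in the tefla package.


def mmd_loss(source_samples, target_samples, weight, name='mmd_loss'):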
    """Adds a similarity loss term, the MMD between two representations.

    This Maximum Mean Discrepancy (MMD) loss is calculated with a number of
    different Gaussian kernels.

    Args:
      source_samples: a tensor of shape [num_samples, num_features].
      target_samples: a tensor of shape [num_samples, num_features].
      weight: the weight of the MMD loss.
      name: optional name scope for the ops created by this loss.

    Returns:
      a scalar tensor representing the MMD loss value.
    """
    with tf.name_scope(name):
        sigmas = [
            1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
            1e3, 1e4, 1e5, 1e6
        ]
        gaussian_kernel = partial(
            util.gaussian_kernel_matrix, sigmas=tf.constant(sigmas))

        loss_value = maximum_mean_discrepancy(
            source_samples, target_samples, kernel=gaussian_kernel)
        loss_value = tf.maximum(1e-4, loss_value) * weight
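    # Fail fast if the loss is NaN or Inf.
    assert_op = tf.Assert(tf.is_finite(loss_value), [loss_value])
    with tf.control_dependencies([assert_op]):
        # Tie the assertion to the returned tensor so that it runs whenever
        # the loss is evaluated. Op names may not contain spaces.
        loss_value = tf.identity(loss_value, name='mmd_loss_checked')
    return loss_value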
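
To make the computation concrete, here is a dependency-free sketch of what a multi-kernel Gaussian MMD loss of this kind might compute. It is not tefla's implementation. The helper names (`gaussian_kernel_sum`, `mean_kernel`, `mmd_squared`, `weighted_mmd_loss`) are hypothetical, the short `sigmas` default is chosen only for illustration, and the bandwidth convention `exp(-d^2 / (2 * sigma))` is an assumption about how `util.gaussian_kernel_matrix` behaves. The only parts taken from the function above are the `1e-4` floor and the multiplication by `weight`.

```python
import math


def gaussian_kernel_sum(x, y, sigmas):
    """Sum of Gaussian kernels between two feature vectors (assumed convention)."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return sum(math.exp(-sq_dist / (2.0 * s)) for s in sigmas)


def mean_kernel(xs, ys, sigmas):
    """Mean kernel value over all pairs (x, y) with x in xs and y in ys."""
    total = sum(gaussian_kernel_sum(x, y, sigmas) for x in xs for y in ys)
    return total / (len(xs) * len(ys))


def mmd_squared(source, target, sigmas):
    """Biased estimate of squared MMD: E[k(s,s)] + E[k(t,t)] - 2 E[k(s,t)]."""
    cost = (mean_kernel(source, source, sigmas)
            + mean_kernel(target, target, sigmas)
            - 2.0 * mean_kernel(source, target, sigmas))
    return max(cost, 0.0)  # clip tiny negative values caused by rounding


def weighted_mmd_loss(source, target, weight, sigmas=(0.1, 1.0, 10.0), floor=1e-4):
    """Apply the same floor-then-weight step as mmd_loss above."""
    return max(floor, mmd_squared(source, target, sigmas)) * weight


same = [[0.0, 0.0], [0.0, 1.0]]
shifted = [[5.0, 5.0], [5.0, 6.0]]
loss_same = weighted_mmd_loss(same, same, weight=2.0)
loss_shifted = weighted_mmd_loss(same, shifted, weight=2.0)
```

With identical source and target samples the MMD term is zero, so the result is just the floor times the weight. Samples drawn from clearly separated regions give a larger loss, which is what makes the term usable as a domain-similarity penalty.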