def mmd_loss(source_samples, target_samples, weight, name='mmd_loss'):
"""Adds a similarity loss term, the MMD between two representations.
This Maximum Mean Discrepancy (MMD) loss is calculated with a number of
different Gaussian kernels.
Args:
source_samples: a tensor of shape [num_samples, num_features].
target_samples: a tensor of shape [num_samples, num_features].
weight: the weight of the MMD loss.
scope: optional name scope for summary tags.
Returns:
a scalar tensor representing the MMD loss value.
"""
with tf.name_scope(name):
sigmas = [
1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
1e3, 1e4, 1e5, 1e6
]
gaussian_kernel = partial(
util.gaussian_kernel_matrix, sigmas=tf.constant(sigmas))
loss_value = maximum_mean_discrepancy(
source_samples, target_samples, kernel=gaussian_kernel)
loss_value = tf.maximum(1e-4, loss_value) * weight
assert_op = tf.Assert(tf.is_finite(loss_value), [loss_value])
with tf.control_dependencies([assert_op]):
tag = 'MMD Loss'
barrier = tf.no_op(tag)
return loss_value
评论列表
文章目录