def difference_loss(private_samples, shared_samples, weight=1.0,
                    name='difference_loss'):
    """Computes the difference loss between private and shared representations.

    Both inputs are centered per feature (mean over axis 0 subtracted) and
    L2-normalized along axis 1. The loss is the mean of the squared entries
    of ``private^T @ shared``, scaled by ``weight``.

    Args:
        private_samples: a tensor of shape [num_samples, num_features].
        shared_samples: a tensor of shape [num_samples, num_features].
        weight: the weight of the incoherence loss.
        name: the name of the enclosing name scope; also used as the name of
            the returned op.

    Returns:
        A scalar tensor holding the weighted loss. Evaluating it also runs an
        assertion that the value is finite.
    """
    with tf.name_scope(name):
        # Center each feature column. New locals are used instead of `-=` so
        # that a mutable array-like argument is never modified in place.
        private_centered = private_samples - tf.reduce_mean(private_samples, 0)
        shared_centered = shared_samples - tf.reduce_mean(shared_samples, 0)

        # Scale every sample (row) to unit L2 norm.
        private_unit = tf.nn.l2_normalize(private_centered, 1)
        shared_unit = tf.nn.l2_normalize(shared_centered, 1)

        # Feature-by-feature correlation between the two representations.
        correlation_matrix = tf.matmul(
            private_unit, shared_unit, transpose_a=True)

        cost = tf.reduce_mean(tf.square(correlation_matrix)) * weight
        # Anything that is not strictly positive is mapped to 0. Note that a
        # NaN compares False against 0 and is therefore mapped to 0 as well.
        cost = tf.where(cost > 0, cost, 0, name='value')

        assert_op = tf.Assert(tf.is_finite(cost), [cost])
        # Route the returned tensor through the assertion. Previously the
        # assertion was attached to an unused no-op, so it never ran when the
        # returned cost was evaluated.
        with tf.control_dependencies([assert_op]):
            cost = tf.identity(cost, name=name)
    return cost
评论列表
文章目录