def dann_loss(source_samples, target_samples, weight, name='dann_loss'):
    """Adds the domain adversarial (DANN) loss.

    Source and target samples are concatenated, passed through a gradient
    reversal op and a two-layer domain classifier (`fc1` -> `fc2`), and the
    sigmoid of the classifier output is scored with a log loss against
    domain labels (0 for source rows, 1 for target rows).

    Args:
        source_samples: a tensor of shape [num_samples, num_features].
        target_samples: a tensor of shape [num_samples, num_features].
        weight: the weight of the loss.
        name: name of the variable scope the ops are built under.

    Returns:
        a scalar tensor representing the domain loss value. Evaluating it
        also runs an assertion that the loss is finite.
    """
    with tf.variable_scope(name):
        # Size each half of the label mask from its own input so the labels
        # line up with the concatenated samples row for row.
        source_batch_size = tf.shape(source_samples)[0]
        target_batch_size = tf.shape(target_samples)[0]

        samples = tf.concat(values=[source_samples, target_samples], axis=0)
        samples = flatten(samples)

        # Domain labels: 0 for source rows, 1 for target rows, in the same
        # order as the concatenation above.
        domain_selection_mask = tf.concat(
            values=[tf.zeros((source_batch_size, 1)),
                    tf.ones((target_batch_size, 1))],
            axis=0)

        grl = gradient_reverse(samples)
        # Re-impose the flattened feature width on the reversed tensor.
        # NOTE(review): presumably gradient_reverse drops static shape
        # information -- confirm against its implementation.
        grl = tf.reshape(grl, (-1, samples.get_shape().as_list()[1]))

        # Domain classifier: 100 hidden units, then a single logit.
        # NOTE(review): the meaning of the positional `True, None` arguments
        # depends on the project's `fc` helper -- verify there.
        grl = fc(grl, 100, True, None, activation=relu, name='fc1')
        logits = fc(grl, 1, True, None, activation=None, name='fc2')
        domain_predictions = tf.sigmoid(logits)

        domain_loss = tf.losses.log_loss(
            domain_selection_mask, domain_predictions, weights=weight)

        # Not consumed in this function; kept so graph construction is
        # unchanged. NOTE(review): confirm whether util.accuracy_tf has side
        # effects (e.g. registers a summary) before removing it.
        domain_accuracy = util.accuracy_tf(domain_selection_mask,
                                           tf.round(domain_predictions))

        assert_op = tf.Assert(tf.is_finite(domain_loss), [domain_loss])
        with tf.control_dependencies([assert_op]):
            tag_loss = 'losses/domain_loss'
            tf.no_op(tag_loss)
            # Control dependencies only attach to ops created inside this
            # block. Returning the pre-existing tensor would skip the
            # assertion, so wrap it in a new identity op that carries it.
            domain_loss = tf.identity(domain_loss)

        return domain_loss
# Comment list
# Table of contents