losses.py source code

python

Project: tefla  Author: openAGI
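import tensorflow as tf  # TensorFlow 1.x API

# flatten, fc, relu, gradient_reverse and util are tefla helpers that
# losses.py is assumed to import elsewhere; they are not shown in this excerpt.


def dann_loss(source_samples, target_samples, weight, name='dann_loss'):
    """Adds the domain adversarial (DANN) loss.

    Args:
      source_samples: a tensor of shape [num_samples, num_features].
      target_samples: a tensor of shape [num_samples, num_features].
      weight: the weight of the loss.
      name: optional variable scope name for the domain classifier.

    Returns:
      a scalar tensor representing the weighted domain classification loss.
    """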
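    with tf.variable_scope(name):
        # Assumes the source and target batches have the same size.
        batch_size = tf.shape(source_samples)[0]
        samples = tf.concat(values=[source_samples, target_samples], axis=0)
        samples = flatten(samples)

        # Domain labels: 0 for source rows, 1 for target rows.
        domain_selection_mask = tf.concat(
            values=[tf.zeros((batch_size, 1)), tf.ones((batch_size, 1))], axis=0)

        # Gradient reversal: identity on the forward pass, negated gradient on
        # the backward pass, so the features are trained to confuse the
        # domain classifier.
        grl = gradient_reverse(samples)
        # Reshape to [-1, num_features] so the feature dimension is static.
        grl = tf.reshape(grl, (-1, samples.get_shape().as_list()[1]))

        # Two-layer domain classifier producing one logit per sample.
        grl = fc(grl, 100, True, None, activation=relu, name='fc1')
        logits = fc(grl, 1, True, None, activation=None, name='fc2')

        domain_predictions = tf.sigmoid(logits)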

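    domain_loss = tf.losses.log_loss(
        domain_selection_mask, domain_predictions, weights=weight)

    # Computed here but not used or returned by this function.
    domain_accuracy = util.accuracy_tf(domain_selection_mask,
                                       tf.round(domain_predictions))

    # Tie the finiteness check to the returned tensor so that it actually runs;
    # an op that nothing depends on is never executed.
    assert_op = tf.Assert(tf.is_finite(domain_loss), [domain_loss])
    with tf.control_dependencies([assert_op]):
        domain_loss = tf.identity(domain_loss, name='domain_loss')

    return domain_loss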
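
The function above depends on TensorFlow 1.x and on tefla helpers that this excerpt does not show. As a rough, framework-free illustration of the two ideas it combines, the sketch below computes a weighted binary log loss over domain labels (source = 0, target = 1) and applies the backward rule of a gradient reversal layer by hand. The names `domain_log_loss` and `reverse_gradient` are hypothetical and are not part of tefla. The epsilon value mirrors the documented default of `tf.losses.log_loss`, and a nonzero scalar weight is assumed.

```python
import math

EPSILON = 1e-7  # assumed to match the default epsilon of tf.losses.log_loss


def sigmoid(x):
    """Logistic function mapping a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-x))


def domain_log_loss(source_logits, target_logits, weight=1.0):
    """Weighted mean binary log loss; source is labelled 0, target 1."""
    logits = list(source_logits) + list(target_logits)
    labels = [0.0] * len(source_logits) + [1.0] * len(target_logits)
    predictions = [sigmoid(z) for z in logits]
    per_sample = [
        -y * math.log(p + EPSILON) - (1.0 - y) * math.log(1.0 - p + EPSILON)
        for y, p in zip(labels, predictions)
    ]
    return weight * sum(per_sample) / len(per_sample)


def reverse_gradient(upstream_grad, scale=1.0):
    """Backward rule of a gradient reversal layer.

    The forward pass is the identity; on the backward pass each incoming
    gradient component is multiplied by -scale.
    """
    return [-scale * g for g in upstream_grad]


if __name__ == "__main__":
    # An undecided classifier (all logits 0) gives a loss close to ln 2.
    print(domain_log_loss([0.0, 0.0], [0.0, 0.0]))
    # Gradients reaching the feature extractor have their sign flipped.
    print(reverse_gradient([0.5, -0.25]))
```

Because the sign is flipped, a gradient step that lowers the domain classifier's loss pushes the upstream features in the direction that raises it, which is the adversarial effect DANN relies on.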