losses.py source code

python
Project: tefla  Author: openAGI
import tensorflow as tf  # the snippet uses the TF 1.x graph-mode API


def difference_loss(private_samples, shared_samples, weight=1.0, name='difference_loss'):
    """Computes the difference loss between the private and shared representations.

    Args:
      private_samples: a tensor of shape [num_samples, num_features].
      shared_samples: a tensor of shape [num_samples, num_features].
      weight: scalar multiplier applied to the difference loss.
      name: name of the enclosing tf name scope.

    Returns:
      A scalar tensor holding the weighted difference loss.
    """
    with tf.name_scope(name):
        private_samples -= tf.reduce_mean(private_samples, 0)
        shared_samples -= tf.reduce_mean(shared_samples, 0)

        private_samples = tf.nn.l2_normalize(private_samples, 1)
        shared_samples = tf.nn.l2_normalize(shared_samples, 1)

        correlation_matrix = tf.matmul(
            private_samples, shared_samples, transpose_a=True)

        cost = tf.reduce_mean(tf.square(correlation_matrix)) * weight
        cost = tf.where(cost > 0, cost, 0, name='value')

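    # Fail at run time if the loss is not finite. The returned tensor must
    # depend on the assert; otherwise the check is never executed.
    assert_op = tf.Assert(tf.is_finite(cost), [cost])
    with tf.control_dependencies([assert_op]):
        cost = tf.identity(cost)
    return cost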
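
The steps in difference_loss (center each feature over the batch, L2-normalize each sample, form the feature cross-correlation matrix, average its squared entries) can be checked without TensorFlow. The following is a minimal NumPy sketch under stated assumptions: the helper name difference_loss_np and the eps parameter are hypothetical and not part of tefla, and eps is meant to mimic the small guard that tf.nn.l2_normalize applies to the squared norm.

```python
import numpy as np


def difference_loss_np(private_samples, shared_samples, weight=1.0, eps=1e-12):
    """NumPy reference for the difference loss (hypothetical helper, not in tefla)."""
    private = np.asarray(private_samples, dtype=np.float64)
    shared = np.asarray(shared_samples, dtype=np.float64)

    # Center each feature (column) over the batch.
    private = private - private.mean(axis=0)
    shared = shared - shared.mean(axis=0)

    # L2-normalize each sample (row); eps guards against all-zero rows.
    private = private / np.sqrt(
        np.maximum((private ** 2).sum(axis=1, keepdims=True), eps))
    shared = shared / np.sqrt(
        np.maximum((shared ** 2).sum(axis=1, keepdims=True), eps))

    # Cross-correlation between private and shared features,
    # shape [num_private_features, num_shared_features].
    correlation = private.T @ shared

    # Mean of squared entries, scaled by the loss weight.
    return float(np.mean(correlation ** 2) * weight)


# Two representations whose features vary independently across the batch.
private = np.array([[1.0, 0.0], [-1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]])
shared = np.array([[0.0, 1.0], [0.0, 1.0], [0.0, -1.0], [0.0, -1.0]])

loss_independent = difference_loss_np(private, shared)
loss_identical = difference_loss_np(private, private)
```

In this toy batch the private and shared features are uncorrelated, so the loss is zero, while feeding the same representation to both arguments gives a strictly positive loss. That is the behavior the loss is meant to penalize: overlap between the private and shared encodings.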