losses.py source code

python

Project: tefla  Author: openAGI
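import tensorflow as tf  # TensorFlow 1.x graph API (tf.variable_scope, tf.get_variable, tf.scatter_sub)


def center_loss(features, label, alpha, num_classes, name='center_loss'):
    """Center loss based on the paper "A Discriminative Feature Learning Approach for Deep Face Recognition"
       (http://ydwen.github.io/papers/WenECCV16.pdf)

    Args:
        features: 2-D `tensor` [batch_size, feature_length], input features
        label: 1-D integer `tensor` [batch_size], class label of each sample
        alpha: a `float`, center update parameter; each selected center moves
            toward its samples by a factor of (1 - alpha) per update
        num_classes: an `int`, number of classes for training

    Returns:
        loss: a scalar `tensor`, sum((features - centers_batch) ** 2) / 2
        centers: the centers `tensor` after the scatter update; it has to be
            evaluated during training for the update to be applied
    """
    with tf.variable_scope(name):
        num_features = features.get_shape()[1]
        # One non-trainable center per class, initialized at the origin.
        centers = tf.get_variable('centers', [num_classes, num_features], dtype=tf.float32,
                                  initializer=tf.constant_initializer(0), trainable=False)
        label = tf.reshape(label, [-1])
        # Look up the center that belongs to each sample in the batch.
        centers_batch = tf.gather(centers, label)
        # Step that pulls each selected center toward its sample.
        diff = (1 - alpha) * (centers_batch - features)
        centers = tf.scatter_sub(centers, label, diff)
        # tf.nn.l2_loss(t) computes sum(t ** 2) / 2 (no square root, no mean).
        loss = tf.nn.l2_loss(features - centers_batch)
        return loss, centers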
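
The arithmetic inside `center_loss` can be checked without building a TensorFlow graph. The following is a minimal NumPy sketch, not part of tefla: the helper name `center_loss_numpy` and the sample values are hypothetical, and it assumes that repeated labels in a batch accumulate their updates, which is how `tf.scatter_sub` is documented to treat duplicate indices.

```python
import numpy as np


def center_loss_numpy(features, labels, centers, alpha):
    """Hypothetical NumPy mirror of the center loss computation above.

    Returns the loss and a new centers array; the input centers are not modified.
    """
    # Center assigned to each sample, like tf.gather(centers, label).
    centers_batch = centers[labels]
    # Step that moves each selected center toward its sample.
    diff = (1.0 - alpha) * (centers_batch - features)
    new_centers = centers.copy()
    # Unbuffered subtraction: rows with repeated labels accumulate their updates.
    np.subtract.at(new_centers, labels, diff)
    # Same quantity as tf.nn.l2_loss: sum of squares divided by 2.
    loss = 0.5 * np.sum((features - centers_batch) ** 2)
    return loss, new_centers


# Illustrative values only: two samples, three classes, centers at the origin.
features = np.array([[1.0, 2.0], [3.0, 4.0]])
labels = np.array([0, 1])
centers = np.zeros((3, 2))
loss, new_centers = center_loss_numpy(features, labels, centers, alpha=0.5)
```

With `alpha=0.5`, each selected center moves halfway toward its sample, while the center of class 2, which does not appear in the batch, stays at the origin. The loss is computed against the centers as they were before the update, matching the order of operations in the function above.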