def center_loss(features, label, alpha, num_classes, name='center_loss',
                normalize_by_count=False):
    """Center loss from "A Discriminative Feature Learning Approach for Deep Face Recognition".

    Wen et al., ECCV 2016 (http://ydwen.github.io/papers/WenECCV16.pdf).

    The loss is ``0.5 * sum_i ||features_i - c_{label_i}||^2`` (computed with
    ``tf.nn.l2_loss``). Here ``c`` is a non-trainable, zero-initialized
    ``[num_classes, num_features]`` float32 variable holding one center per
    class.

    The centers are not learned by gradient descent. Instead, this function
    builds an update op that moves each center towards the features of its
    class::

        c_j <- c_j - (1 - alpha) * (c_j - x_i)

    For a class that appears once in the batch, the new center is therefore
    ``alpha * c_j + (1 - alpha) * x_i``. Note that ``alpha`` here acts as a
    retention factor (``alpha == 1`` freezes the centers). This is the
    complement of the paper's step size ``alpha``.

    The returned ``loss`` does NOT depend on the update op. The caller must
    run the returned ``centers`` tensor (e.g. together with the train op) for
    the centers to actually change.

    Args:
        features: 2-D float32 `Tensor` [batch_size, feature_length]. The
            feature dimension must be statically known.
        label: integer `Tensor` with batch_size elements (it is reshaped to
            1-D). Each element is the class index of the matching feature row.
        alpha: center retention factor (see above).
        num_classes: an `int`, the number of classes for training.
        name: name of the variable scope in which the ``'centers'`` variable
            is created (or reused).
        normalize_by_count: if True, divide each sample's center update by
            ``1 + n_j``, where ``n_j`` is how many times that sample's class
            occurs in the batch. This follows the paper's center update rule.
            If False (default), updates for repeated labels are simply summed
            by ``scatter_sub``.

    Returns:
        A tuple ``(loss, centers)``:

        - ``loss`` is a scalar `Tensor`.
        - ``centers`` is the `Tensor` produced by the ``scatter_sub`` update
          op, i.e. the centers after the update.

    Raises:
        ValueError: if ``features`` is not rank 2, or its second dimension is
            not statically known.
    """
    # The centers variable needs a fully defined shape, so validate up front
    # and give a clearer message than the variable-creation error would.
    feature_shape = features.get_shape().as_list()
    if len(feature_shape) != 2 or feature_shape[1] is None:
        raise ValueError(
            'center_loss expects 2-D features with a known feature '
            'dimension, got shape %s' % (feature_shape,))
    num_features = feature_shape[1]

    with tf.variable_scope(name):
        centers = tf.get_variable('centers', [num_classes, num_features],
                                  dtype=tf.float32,
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
        label = tf.reshape(label, [-1])
        # Center of each sample's class, read before the update below.
        centers_batch = tf.gather(centers, label)
        diff = (1 - alpha) * (centers_batch - features)
        if normalize_by_count:
            # Count occurrences of each sample's label in the batch, so a
            # class seen k times gets its summed update scaled by 1 / (1 + k).
            _, unique_idx, unique_count = tf.unique_with_counts(label)
            appear_times = tf.reshape(tf.gather(unique_count, unique_idx),
                                      [-1, 1])
            diff = diff / tf.cast(1 + appear_times, diff.dtype)
        # scatter_sub accumulates contributions for duplicate indices.
        centers = tf.scatter_sub(centers, label, diff)
        # centers_batch comes from a non-trainable variable, so gradients of
        # this loss only flow into `features`.
        loss = tf.nn.l2_loss(features - centers_batch)
    return loss, centers
评论列表
文章目录