def center_loss(features, label, alpha, num_classes):
    """Center loss based on the paper "A Discriminative Feature Learning Approach
    for Deep Face Recognition" (http://ydwen.github.io/papers/WenECCV16.pdf).

    Keeps one non-trainable center vector per class and penalizes the squared
    distance between each feature vector and the center of its class.

    Args:
        features: float tensor of per-sample feature vectors. Its second
            dimension must be statically known because it sizes the centers
            variable. Presumably shaped (batch, dim) -- confirm against callers.
        label: integer class ids, one per sample; flattened to 1-D here.
        alpha: center-update retention factor. Each step moves a class center
            (1 - alpha) of the way toward the batch mean of that class's
            samples, so this ``alpha`` corresponds to ``1 - alpha`` in the
            paper's notation.
        num_classes: number of rows (classes) in the centers variable.

    Returns:
        A ``(loss, centers)`` pair:
        * ``loss`` is ``tf.nn.l2_loss`` of the feature-to-center residuals,
          i.e. half the summed squared distance. It is measured against the
          centers gathered *before* this step's update.
        * ``centers`` is the tensor produced by the ``tf.scatter_sub`` update.
          The stored centers only change when that op is executed, so callers
          must fetch it (or add it as a control dependency) during training.
    """
    dim_features = features.get_shape()[1]
    centers = tf.get_variable('centers', [num_classes, dim_features], dtype=tf.float32,
                              initializer=tf.constant_initializer(0), trainable=False)
    label = tf.reshape(label, [-1])
    center_feats = tf.gather(centers, label)

    # How many times each sample's label occurs in this batch, shaped (batch, 1)
    # so it broadcasts across the feature dimension.
    _, unique_idx, unique_count = tf.unique_with_counts(label)
    appear_times = tf.cast(tf.reshape(tf.gather(unique_count, unique_idx), [-1, 1]),
                           tf.float32)

    # tf.scatter_sub adds up the contributions of repeated indices. Without this
    # division, a class seen k times would move k * (1 - alpha) of the way to its
    # batch mean: it overshoots once k * (1 - alpha) > 1 and diverges past 2.
    # Dividing by the count makes the step (1 - alpha) * (center - class mean).
    # That equals the previous behavior whenever a label appears only once.
    diff = (1 - alpha) * tf.subtract(center_feats, features) / appear_times
    centers = tf.scatter_sub(centers, label, diff)
    loss = tf.nn.l2_loss(features - center_feats)
    return loss, centers
# 评论列表 (web-page residue: "Comment list")
# 文章目录 (web-page residue: "Table of contents")