def center_loss(features, label, label_stats, centers, alfa):
    """Compute the center loss and build the op that updates the class centers.

    Center loss (as in Wen et al., 2016) penalises the distance between each
    embedding and the center of its class:

        loss = sum_i ||features_i - centers[y_i]||^2 / 2

    ``tf.nn.l2_loss`` supplies the 1/2 factor.

    Args:
        features: [batch_size, feature_dim] float tensor, the image embeddings
            (feature_dim was 512 in the original setup).
        label: [batch_size, class_num] one-hot class labels; the true class
            index holds 1 and all other entries are 0.
        label_stats: [batch_size, 1] tensor. Row i is expected to hold the
            number of times sample i's label occurs in the batch. Its dtype
            must be able to divide ``features`` (i.e. float).
            NOTE(review): the paper's update divides by ``1 + count``. This
            function divides by ``label_stats`` exactly as given, so confirm
            whether the caller already includes the +1.
        centers: [class_num, feature_dim] variable with one center per class.
            It must be a mutable (TF1) ``Variable`` because it is updated in
            place with ``tf.scatter_sub``.
        alfa: float, update rate of the centers.

    Returns:
        A tuple ``(loss, centers)``:
          - ``loss``: scalar center loss, measured against the pre-update
            centers.
          - ``centers``: the tensor returned by ``tf.scatter_sub``, i.e. the
            centers after the update. In graph mode, the update only runs
            when this tensor is evaluated or depended on.
    """
    # One-hot -> integer class index per sample, flattened to [batch_size].
    # tf.argmax replaces the deprecated tf.arg_max alias.
    label = tf.reshape(tf.argmax(label, 1), [-1])

    # Current center of each sample's class. This copy is taken before the
    # update, so the loss below uses the pre-update centers.
    centers_batch = tf.gather(centers, label)

    # Per-sample center step. scatter_sub adds up the rows for duplicate
    # indices. If label_stats holds the per-sample label count k, a class
    # seen k times therefore moves by alfa times the mean of its k
    # (center - feature) differences.
    diff = alfa * (centers_batch - features) / label_stats
    centers = tf.scatter_sub(centers, label, diff)

    loss = tf.nn.l2_loss(features - centers_batch)
    return loss, centers
# Web-page residue from the source article (not code):
# 评论列表 (comment list)
# 文章目录 (table of contents)