cnn.py source code

python

Project: tf-re-id  Author: jhb86253817
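import tensorflow as tf  # TensorFlow 1.x API (tf.scatter_sub is tf.compat.v1.scatter_sub in 2.x)


def center_loss(features, label, label_stats, centers, alfa):
    """The center loss.
       features: [batch_size, 512], the embeddings of the images.
       label: [batch_size, class_num], one-hot class labels (1 at the class index, 0 elsewhere).
       label_stats: [batch_size, 1], for each sample, the number of samples in the batch
           with the same label; must have the same float dtype as features.
       centers: [class_num, 512], a tf.Variable holding one center point per class.
       alfa: float, update rate of the centers.
    """
    # One-hot labels -> integer class indices, shape [batch_size].
    label = tf.argmax(label, 1)
    label = tf.reshape(label, [-1])
    # Current center of each sample's class, shape [batch_size, 512].
    centers_batch = tf.gather(centers, label)
    # Step each center toward its samples. scatter_sub adds up repeated indices, so
    # dividing by the per-label count turns a repeated class's updates into an average.
    diff = alfa * (centers_batch - features)
    diff = diff / label_stats
    centers = tf.scatter_sub(centers, label, diff)
    # tf.nn.l2_loss(t) = sum(t ** 2) / 2, summed over the whole batch (not averaged).
    loss = tf.nn.l2_loss(features - centers_batch)
    return loss, centers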
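Written out as math, with x_i the i-th row of features, y_i its class index, c_j the j-th row of centers, α = alfa, m the batch size, and n_j the number of samples of class j in the batch (assuming label_stats holds n_{y_i} for sample i, as the docstring describes), the function returns the loss and center update below. The loss uses the centers as they were before the update. Because tf.scatter_sub adds the contributions of repeated indices, each class present in the batch moves a fraction α of the way toward the mean of its features, while classes absent from the batch are left unchanged.

```latex
\mathcal{L}_C = \frac{1}{2} \sum_{i=1}^{m} \lVert x_i - c_{y_i} \rVert_2^2

c_j \leftarrow c_j - \alpha \sum_{i:\, y_i = j} \frac{c_j - x_i}{n_j}
    = (1 - \alpha)\, c_j + \alpha\, \bar{x}_j,
\qquad \bar{x}_j = \frac{1}{n_j} \sum_{i:\, y_i = j} x_i
```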
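To check the semantics of center_loss without a TensorFlow session, here is a plain-Python reference sketch of the same forward pass and center update. The helper names label_counts and center_loss_reference are hypothetical (not part of tf-re-id), and it assumes one-hot labels and a label_stats built as "how many samples in the batch share this sample's label", which is what label_counts computes.

```python
# Hypothetical plain-Python reference for center_loss; not part of tf-re-id.
from collections import Counter
from typing import List, Tuple

Vector = List[float]


def label_counts(labels: List[int]) -> List[int]:
    """Per-sample count of batch samples sharing its label (assumed label_stats)."""
    counts = Counter(labels)
    return [counts[y] for y in labels]


def center_loss_reference(
    features: List[Vector],
    one_hot: List[Vector],
    centers: List[Vector],
    alfa: float,
) -> Tuple[float, List[Vector]]:
    # One-hot rows -> integer class indices (what tf.argmax(label, 1) does).
    labels = [row.index(max(row)) for row in one_hot]
    stats = label_counts(labels)
    # Centers are read before the update, like centers_batch in center_loss.
    centers_batch = [list(centers[y]) for y in labels]

    # tf.nn.l2_loss(t) == sum(t ** 2) / 2 over every element.
    loss = 0.5 * sum(
        (f - c) ** 2
        for feat, cen in zip(features, centers_batch)
        for f, c in zip(feat, cen)
    )

    # Like tf.scatter_sub: subtract alfa * (c - x) / n from the center of each
    # sample's class; repeated class indices accumulate. Input is not mutated.
    new_centers = [list(c) for c in centers]
    for y, n, feat, cen in zip(labels, stats, features, centers_batch):
        for k in range(len(feat)):
            new_centers[y][k] -= alfa * (cen[k] - feat[k]) / n
    return loss, new_centers


features = [[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]]
one_hot = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
loss, new_centers = center_loss_reference(features, one_hot, [[0.0, 0.0], [0.0, 0.0]], 0.5)
print(loss, new_centers)  # 7.0 [[1.0, 0.0], [0.0, 1.0]]
```

With alfa = 0.5 each center moves halfway toward its class mean: class 0 has mean feature [2.0, 0.0], so its center goes from [0.0, 0.0] to [1.0, 0.0], which matches (1 − α)c_j + α x̄_j.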