model_func.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:traffic_video_analysis 作者: polltooh 项目源码 文件源码
def triplet_loss(infer, labels, radius = 2.0):
    """
    Args:
        infer: inference concatenate together with 2 * batch_size
        labels: 0 or 1 with batch_size
        radius:
    Return:
        loss: triplet loss
    """

    feature_1, feature_2 = tf.split(0,2,infer)

    feature_diff = tf.reduce_sum(tf.square(feature_1 - feature_2), 1)
    feature_list = tf.dynamic_partition(feature_diff, labels, 2)

    pos_list = feature_list[1]
    neg_list  = (tf.maximum(0.0, radius * radius - feature_list[0]))
    full_list = tf.concat(0,[pos_list, neg_list])
    loss = tf.reduce_mean(full_list)

    return loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号