def triplet_loss(infer, labels, radius=2.0):
    """Compute a pairwise margin loss over two stacked batches of embeddings.

    Despite the name, no (anchor, positive, negative) triplets are formed.
    ``infer`` is split into two equal halves along axis 0, row ``i`` of the
    first half is paired with row ``i`` of the second half, and every pair
    contributes one term:

    * label 1: the squared Euclidean distance between the two rows;
    * label 0: ``max(0, radius ** 2 - squared distance)``.

    The loss is the mean of all per-pair terms.

    ``tf.split`` and ``tf.concat`` are called with the TensorFlow >= 1.0
    signatures (value first, axis last); the older axis-first positional
    order is rejected by those releases.

    Args:
        infer: Tensor holding both halves stacked along axis 0, so its first
            dimension must be even (2 * batch_size). Assumed to be of shape
            [2 * batch_size, feature_dim] -- TODO confirm against the caller.
        labels: One value per pair (batch_size entries), each 0 or 1. Any
            numeric dtype is accepted; it is cast to int32 internally.
        radius: Margin for label-0 pairs. Such a pair contributes nothing
            once its squared distance reaches ``radius ** 2``.

    Returns:
        Scalar tensor: the mean per-pair loss.
    """
    first_half, second_half = tf.split(
        value=infer, num_or_size_splits=2, axis=0)

    # Squared Euclidean distance of each pair (summed over the feature axis).
    squared_dist = tf.reduce_sum(tf.square(first_half - second_half), axis=1)

    # tf.dynamic_partition only accepts int32 partition indices, so float or
    # int64 labels would otherwise be rejected.
    partition_ids = tf.cast(labels, tf.int32)
    label0_dist, label1_dist = tf.dynamic_partition(
        squared_dist, partition_ids, 2)

    # Label-1 pairs are penalised by their distance directly; label-0 pairs
    # only while they are still closer than the margin.
    label1_terms = label1_dist
    label0_terms = tf.maximum(0.0, radius * radius - label0_dist)

    all_terms = tf.concat(values=[label1_terms, label0_terms], axis=0)
    return tf.reduce_mean(all_terms)
# (web-page residue, not code) Comment list
# (web-page residue, not code) Table of contents