def cosine_ranking_loss(input_data, ctx, margin=0.1):
    """Hinge ranking loss comparing each prediction with its own target vs. a shuffled one.

    Each row of ``input_data`` is scored against two rows of ``ctx``:

    * the positive: the same row of ``ctx``;
    * the negative: a row picked by a random permutation of the batch.

    The per-row cost is ``max(0, margin + negative - positive)``. Scores are
    dot products of the ``_normalize``-d rows. If ``_normalize`` L2-normalizes
    each row, these scores are cosine similarities (TODO confirm against
    ``_normalize``).

    Because the negatives come from ``torch.randperm``, a row can be paired
    with its own target (a fixed point of the permutation). For such a row the
    negative equals the positive, and its cost is ``max(0, margin)``.

    :param input_data: [batch_size, 300] tensor of predictions (rows along
        dim 0, features along dim 1).
    :param ctx: [batch_size, 300] tensor of ground truths, same shape as
        ``input_data``.
    :param margin: amount by which the positive score must exceed the negative
        score before the row's cost drops to zero.
    :return: tuple ``(cost, correct_contrib, incorrect_contrib)`` of per-row
        tensors: the clamped hinge cost, the positive score and the negative
        score. ``squeeze()`` is a no-op for batch_size > 1; for a batch of one
        the results are 0-d tensors.
    """
    normed = _normalize(input_data)
    ctx_normed = _normalize(ctx)

    # Draw the permutation from the CPU generator, as before, then place it on
    # the same device as the tensor it indexes. This works for any device,
    # not only the current default CUDA device.
    shuff_inds = torch.randperm(normed.size(0)).to(ctx_normed.device)
    shuff = ctx_normed[shuff_inds]

    # Row-wise dot products: positive = own target, negative = shuffled target.
    correct_contrib = torch.sum(normed * ctx_normed, 1).squeeze()
    incorrect_contrib = torch.sum(normed * shuff, 1).squeeze()

    # Hinge: zero cost once the positive beats the negative by at least `margin`.
    cost = (margin + incorrect_contrib - correct_contrib).clamp(min=0)
    return cost, correct_contrib, incorrect_contrib
评论列表
文章目录