def pairwise_ranking_loss(margin, x, v):
    """Compute a bidirectional max-margin ranking loss on cosine similarity.

    Row ``i`` of ``x`` and row ``i`` of ``v`` are treated as a matching pair;
    every other row combination in the batch acts as a negative. Both inputs
    are L2-normalized along dim 1, so their matrix product holds cosine
    similarities, and a hinge with the given margin is applied in both
    directions (each ``x`` row against all ``v`` rows, and vice versa).

    Args:
        margin: Scalar margin required between a matching pair's similarity
            and the similarity of any non-matching pair.
        x: 2-D tensor of shape ``(N, D)``.
        v: 2-D tensor of shape ``(N, D)``; must share ``x``'s device.

    Returns:
        A 0-dim tensor: the sum of both hinge terms divided by ``N``.

    Note:
        A row whose L2 norm is zero divides by zero during normalization and
        makes the result NaN.
    """
    # Normalize rows so the product below is a cosine-similarity matrix.
    x = x / torch.norm(x, 2, 1, keepdim=True)
    v = v / torch.norm(v, 2, 1, keepdim=True)

    scores = torch.matmul(x, v.transpose(0, 1))
    matched = torch.diag(scores)

    # Build the mask on the same dtype/device as the scores instead of
    # consulting a global CUDA flag, so the function works wherever the
    # inputs live and never mixes devices.
    diag_margin = margin * torch.eye(
        x.size(0), dtype=scores.dtype, device=scores.device
    )

    # Each diagonal entry of the hinge equals clamp(margin, min=0); subtracting
    # diag_margin removes the matching pair's own contribution.
    for_x = torch.clamp(margin - matched.unsqueeze(1) + scores, min=0) - diag_margin
    for_v = torch.clamp(margin - matched.unsqueeze(0) + scores, min=0) - diag_margin

    return (torch.sum(for_x) + torch.sum(for_v)) / x.size(0)
评论列表
文章目录