def forward(self, inputs, targets):
n = inputs.size(0)
# Compute pairwise distance, replace by the official when merged
dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
dist = dist + dist.t()
dist.addmm_(1, -2, inputs, inputs.t())
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
# For each anchor, find the hardest positive and negative
mask = targets.expand(n, n).eq(targets.expand(n, n).t())
dist_ap, dist_an = [], []
for i in range(n):
dist_ap.append(dist[i][mask[i]].max())
dist_an.append(dist[i][mask[i] == 0].min())
dist_ap = torch.cat(dist_ap)
dist_an = torch.cat(dist_an)
# Compute ranking hinge loss
y = dist_an.data.new()
y.resize_as_(dist_an.data)
y.fill_(1)
y = Variable(y)
loss = self.ranking_loss(dist_an, dist_ap, y)
prec = (dist_an.data > dist_ap.data).sum() * 1. / y.size(0)
return loss, prec
评论列表
文章目录