def forward(self, anchor, positive, negative):
    """Compute a triplet margin loss over a batch of embeddings.

    For every triplet the hinge ``max(d(a, p) - d(a, n) + margin, 0)`` is
    evaluated, where ``d`` is selected by ``self.dist_type``:

    * ``0`` -- ``F.pairwise_distance``.
    * ``1`` -- cosine distance, computed as ``1 - cosine_similarity``.

    When ``self.use_ohem`` is truthy, only the ``self.ohem_bs`` largest
    hinge values (the hardest triplets) are averaged; otherwise the mean
    over all triplets is returned.

    Args:
        anchor: Anchor embeddings.
        positive: Embeddings that should end up close to the anchors.
        negative: Embeddings that should end up far from the anchors.
            (All three are assumed to be ``(batch, dim)`` tensors of the
            same shape -- TODO confirm against the caller.)

    Returns:
        A scalar tensor holding the mean hinge loss.

    Raises:
        ValueError: If ``self.dist_type`` is neither 0 nor 1.
    """
    if self.dist_type == 0:
        dist_p = F.pairwise_distance(anchor, positive)
        dist_n = F.pairwise_distance(anchor, negative)
    elif self.dist_type == 1:
        # The hinge below expects distances (smaller = closer).
        # cosine_similarity is assumed to return a similarity
        # (larger = closer), so it is converted to a distance; feeding the
        # raw similarity in would reward pushing positives away.
        dist_p = 1.0 - cosine_similarity(anchor, positive)
        dist_n = 1.0 - cosine_similarity(anchor, negative)
    else:
        raise ValueError(
            f"unsupported dist_type {self.dist_type!r}; expected 0 or 1"
        )

    dist_hinge = torch.clamp(dist_p - dist_n + self.margin, min=0.0)

    if self.use_ohem:
        # Online hard example mining: keep only the largest losses.
        hardest, _ = torch.sort(dist_hinge, descending=True)
        return torch.mean(hardest[:self.ohem_bs])
    return torch.mean(dist_hinge)
评论列表
文章目录