def forward(self, anchor, positive, negative):
    """Return the mean triplet margin (hinge) loss for a batch of triplets.

    For each element the loss is ``max(margin + d(anchor, positive)
    - d(anchor, negative), 0)``, where ``d`` is ``self.pdist.forward``.
    The result is averaged over every element of that hinge tensor.

    Args:
        anchor: Anchor embeddings. Presumably shaped (batch, dim) like the
            other two inputs; this is not checked here, so TODO confirm
            against callers.
        positive: Embeddings meant to be close to ``anchor``.
        negative: Embeddings meant to be far from ``anchor``.

    Returns:
        A tensor holding the mean of the clamped hinge values (0-dim when
        the hinge is a tensor).
    """
    # NOTE(review): ``.forward`` is called directly rather than
    # ``self.pdist(...)``. ``pdist``'s type isn't visible here (it may be a
    # legacy autograd Function or another non-callable object), so the
    # direct call is preserved.
    distance = self.pdist.forward
    pos_dist = distance(anchor, positive)
    neg_dist = distance(anchor, negative)

    # Only triplets whose negative is not at least ``margin`` farther than
    # the positive contribute a non-zero penalty.
    hinge = (self.margin + pos_dist - neg_dist).clamp(min=0.0)
    return hinge.mean()
评论列表
文章目录