TripletLoss.py source code

python

Project: pytorch-PersonReID    Author: huaijin-chen
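import torch
import torch.nn.functional as F


def forward(self, anchor, positive, negative):
    # Triplet margin loss: require d(anchor, positive) to be smaller than
    # d(anchor, negative) by at least self.margin.
    if self.dist_type == 0:
        # Euclidean (L2) distance between corresponding rows of the batch
        dist_p = F.pairwise_distance(anchor, positive)
        dist_n = F.pairwise_distance(anchor, negative)
    elif self.dist_type == 1:
        # Cosine distance = 1 - cosine similarity, so "smaller = closer" holds
        # for both branches and the hinge below keeps the correct sign
        dist_p = 1.0 - F.cosine_similarity(anchor, positive, dim=1)
        dist_n = 1.0 - F.cosine_similarity(anchor, negative, dim=1)
    else:
        raise ValueError("dist_type must be 0 (Euclidean) or 1 (cosine)")

    # Hinge: a triplet contributes zero loss once the negative is farther
    # from the anchor than the positive by at least the margin
    dist_hinge = torch.clamp(dist_p - dist_n + self.margin, min=0.0)

    if self.use_ohem:
        # Online hard example mining: average only the ohem_bs largest losses
        v, _ = torch.sort(dist_hinge, descending=True)
        loss = torch.mean(v[:self.ohem_bs])
    else:
        loss = torch.mean(dist_hinge)

    return loss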
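For a quick sanity check, the method can be wrapped in an `nn.Module` and compared against PyTorch's built-in `nn.TripletMarginLoss`, which computes the same Euclidean hinge loss. This is a minimal sketch under assumptions: the class name `TripletLoss` and the constructor arguments and defaults (`margin`, `dist_type`, `use_ohem`, `ohem_bs`) are hypothetical, since the original file's `__init__` is not shown, and OHEM here uses `torch.topk` instead of a full sort.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TripletLoss(nn.Module):
    """Hypothetical wrapper: constructor signature and defaults are assumed."""

    def __init__(self, margin=1.0, dist_type=0, use_ohem=False, ohem_bs=128):
        super().__init__()
        self.margin = margin
        self.dist_type = dist_type  # 0: Euclidean, 1: cosine distance
        self.use_ohem = use_ohem
        self.ohem_bs = ohem_bs      # number of hardest triplets kept when use_ohem is True

    def forward(self, anchor, positive, negative):
        if self.dist_type == 0:
            dist_p = F.pairwise_distance(anchor, positive)
            dist_n = F.pairwise_distance(anchor, negative)
        elif self.dist_type == 1:
            # 1 - cosine similarity turns "higher = closer" into "lower = closer"
            dist_p = 1.0 - F.cosine_similarity(anchor, positive, dim=1)
            dist_n = 1.0 - F.cosine_similarity(anchor, negative, dim=1)
        else:
            raise ValueError("dist_type must be 0 (Euclidean) or 1 (cosine)")
        dist_hinge = torch.clamp(dist_p - dist_n + self.margin, min=0.0)
        if self.use_ohem:
            # topk returns the k largest losses, same as sort(descending) + slice
            k = min(self.ohem_bs, dist_hinge.numel())
            hardest, _ = torch.topk(dist_hinge, k)
            return hardest.mean()
        return dist_hinge.mean()


torch.manual_seed(0)
a, p, n = (torch.randn(16, 8) for _ in range(3))

# The Euclidean branch should agree with PyTorch's built-in triplet margin loss
ours = TripletLoss(margin=1.0, dist_type=0)(a, p, n)
ref = nn.TripletMarginLoss(margin=1.0, p=2.0)(a, p, n)
print(torch.allclose(ours, ref))  # should print True

# With OHEM, only the 4 hardest triplets are averaged, so the loss is never smaller
ohem = TripletLoss(margin=1.0, use_ohem=True, ohem_bs=4)(a, p, n)
print(ours.item(), ohem.item())
```

Converting cosine similarity to a distance (rather than plugging the similarity in directly) keeps a single hinge expression, `d_p - d_n + margin`, valid for both `dist_type` values.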