contrastive.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:pytorch-siamese 作者: delijati 项目源码 文件源码
def forward(self, x0, x1, y):
        # euclidian distance
        diff = x0 - x1
        dist_sq = torch.sum(torch.pow(diff, 2), 1)
        dist = torch.sqrt(dist_sq)

        mdist = self.margin - dist
        dist = torch.clamp(mdist, min=0.0)
        loss = y * dist_sq + (1 - y) * torch.pow(dist, 2)
        loss = torch.sum(loss) / 2.0 / x0.size()[0]
        return loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号