tripletnet.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:TripletEmbedding 作者: hrantzsch 项目源码 文件源码
def __call__(self, x, margin_factor=1.0, train=True):
        """
        Embed samples using the CNN, then calculate distances and triplet loss.

        x is a batch of size 3n following the form:

        | anchor_1   |
        | [...]      |
        | anchor_n   |
        | positive_1 |
        | [...]      |
        | positive_n |
        | negative_1 |
        | [...]      |
        | negative_n |
        """
        anc, pos, neg = (self.embed(h) for h in F.split_axis(x, 3, 0))
        dist_pos, dist_neg = self.squared_distance(anc, pos, neg)
        mf = margin_factor if train else 1.0  # no margin when testing
        return self.compute_loss(dist_pos, dist_neg, mf)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号