def __call__(self, x, margin_factor=1.0, train=True):
    """Embed a batch of triplets and return its triplet loss.

    ``x`` is a batch of size 3n, laid out along axis 0 as:

        | anchor_1   |
        | [...]      |
        | anchor_n   |
        | positive_1 |
        | [...]      |
        | positive_n |
        | negative_1 |
        | [...]      |
        | negative_n |

    The batch is split into three equal sections along axis 0, each
    section is passed through ``self.embed``, and the resulting
    embeddings are handed to ``self.squared_distance`` and then to
    ``self.compute_loss``.

    Args:
        x: Batch of 3n samples ordered as shown above.
        margin_factor: Factor forwarded to ``self.compute_loss`` while
            training. Ignored when ``train`` is false.
        train: When false, a fixed factor of 1.0 is forwarded to
            ``self.compute_loss`` instead of ``margin_factor``.

    Returns:
        Whatever ``self.compute_loss`` returns for the anchor-positive
        and anchor-negative distances.
    """
    # One section per role, in the order anchors, positives, negatives.
    anc, pos, neg = (self.embed(h) for h in F.split_axis(x, 3, 0))
    dist_pos, dist_neg = self.squared_distance(anc, pos, neg)

    # Outside of training the caller's margin_factor is overridden by 1.0.
    # NOTE(review): the original comment read "no margin when testing";
    # whether a factor of 1.0 really disables the margin depends on
    # compute_loss, which is not visible here -- confirm.
    mf = margin_factor if train else 1.0
    return self.compute_loss(dist_pos, dist_neg, mf)
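# Comment list (navigation label left over from the hosting web page)
# Table of contents (navigation label left over from the hosting web page)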