tripletnet.py source code

python

Project: TripletEmbedding  Author: hrantzsch
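# Method excerpt from a class in tripletnet.py. `F` is presumably
# chainer.functions; `sqrt` and `mse_zero_one` are defined elsewhere in the
# project and are not shown here.
def compute_loss(self, dist_pos, dist_neg, margin_factor=1.0):
    """
    Use Softmax on the distances as a ratio measure and compare it to a
    vector of [[0, 0, ...] [1, 1, ...]] (Mean Squared Error).
    This function also computes the accuracy and the 'max_distance'.
    """
    # apply the margin factor to the positive distances only, then take the
    # square root of both sets of distances
    dist = sqrt(F.concat((dist_pos * margin_factor, dist_neg)))

    # softmax turns each (positive, negative) distance pair into a ratio
    sm = F.softmax(dist)
    self.loss = mse_zero_one(sm)

    # bookkeeping metrics, computed from the raw (unscaled) distances
    self.accuracy = self._accuracy(dist_pos, dist_neg)
    self.mean_diff = self._mean_difference(dist_pos, dist_neg)
    self.max_diff = self._max_difference(dist_pos, dist_neg)

    return self.loss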
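
The docstring describes the loss only in words, so a framework-free sketch may help. The NumPy version below is an illustration, not the project's code: the helper name `softmax_ratio_loss` is hypothetical, the `(batch, 2)` layout is an assumption, and because `mse_zero_one` is not shown in the excerpt, it is assumed here to be a plain mean squared error against the target `(0, 1)` for each (positive, negative) pair.

```python
import numpy as np


def softmax_ratio_loss(dist_pos, dist_neg, margin_factor=1.0):
    """Hypothetical NumPy sketch of a softmax-ratio triplet loss.

    dist_pos, dist_neg: 1-D arrays of non-negative squared distances
    (anchor to positive, anchor to negative), one entry per triplet.
    """
    dist_pos = np.asarray(dist_pos, dtype=np.float64)
    dist_neg = np.asarray(dist_neg, dtype=np.float64)

    # Assumed layout: shape (batch, 2), column 0 = positive, column 1 = negative.
    dist = np.sqrt(np.stack([dist_pos * margin_factor, dist_neg], axis=1))

    # Numerically stable softmax along each row.
    shifted = dist - dist.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    sm = exp / exp.sum(axis=1, keepdims=True)

    # Assumed target: the positive share should be 0, the negative share 1.
    target = np.array([0.0, 1.0])
    return float(np.mean((sm - target) ** 2))


if __name__ == "__main__":
    # Well-separated triplets give a loss near 0; reversed ones near 1.
    print(softmax_ratio_loss([0.0, 0.0], [100.0, 100.0]))
    print(softmax_ratio_loss([100.0, 100.0], [0.0, 0.0]))
```

Under these assumptions the loss shrinks as the negative distance grows relative to the positive one, and a `margin_factor` above 1 inflates the positive distances so that the pair must be separated further before the loss falls.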