net.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:RankNet 作者: szdr 项目源码 文件源码
def __call__(self, x_i, x_j, t_i, t_j):
        s_i = self.predictor(x_i)
        s_j = self.predictor(x_j)
        s_diff = s_i - s_j
        if t_i.data > t_j.data:
            S_ij = 1
        elif t_i.data < t_j.data:
            S_ij = -1
        else:
            S_ij = 0
        self.loss = (1 - S_ij) * s_diff / 2. + F.log(1 + F.exp(-s_diff))
        return self.loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号