ListNet.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:learning2rank 作者: shiba24 项目源码 文件源码
def kld(self, vec_true, vec_compare):
        ind = vec_true.data * vec_compare.data > 0
        ind_var = chainer.Variable(ind)
        include_nan = vec_true * F.log(vec_true / vec_compare)
        z = chainer.Variable(np.zeros((len(ind), 1), dtype=np.float32))
        # return np.nansum(vec_true * np.log(vec_true / vec_compare))
        return F.sum(F.where(ind_var, include_nan, z))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号