svm.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:pl-cnn 作者: oval-group 项目源码 文件源码
def max_oracle(scores,
                   y_truth):

        n_classes = scores.shape[1]
        t_range = T.arange(y_truth.shape[0])

        # classification loss for any combination
        losses = 1. - T.extra_ops.to_one_hot(y_truth, n_classes)

        # get max score for each sample
        y_star = T.argmax(scores + losses, axis=1)

        # compute classification loss for batch
        delta = losses[t_range, y_star].sum()

        return y_star, delta
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号