losses.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:jaccardSegment 作者: bermanmaxim 项目源码 文件源码
def lovasz_binary(margins, label, prox=False, max_steps=20, debug={}):
    # 1d vector inputs
    # Workaround: can't sort Variable bug
    # prox: False or lambda regularization value
    _, perm = torch.sort(margins.data, dim=0, descending=True)
    margins_sorted = margins[perm]
    grad = gamma_fast(label, perm)
    loss = torch.dot(F.relu(margins_sorted), Variable(grad))
    if prox is not False:
        xp, gam = find_proximal(margins_sorted.data, grad, prox, max_steps=max_steps, eps=1e-6, debug=debug)
        hook = margins_sorted.register_hook(lambda grad: Variable(margins_sorted.data - xp))
        return loss, hook, gam
    else:
        return loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号