def lovasz_binary(margins, label, prox=False, max_steps=20, debug=None):
    """Compute the loss ``dot(relu(sorted margins), gamma_fast(label, perm))``.

    Presumably the binary Lovasz hinge; the weights come from ``gamma_fast``
    (defined elsewhere — confirm its contract there).

    Args:
        margins: 1-D vector of margins (Variable/tensor with ``.data``).
        label: 1-D labels, passed to ``gamma_fast`` together with the
            descending-sort permutation of ``margins``.
        prox: ``False`` to return only the loss. Any other value is used as
            the lambda regularization value forwarded to ``find_proximal``.
        max_steps: iteration cap forwarded to ``find_proximal``.
        debug: optional dict forwarded to ``find_proximal``. When omitted, a
            fresh dict is created for this call.

    Returns:
        ``loss`` when ``prox is False``. Otherwise ``(loss, hook, gam)``:
        ``hook`` is the handle returned by ``register_hook`` (callers can
        ``.remove()`` it), and ``gam`` is the second value returned by
        ``find_proximal``.
    """
    # A literal ``{}`` default would be one dict shared by every call, so
    # anything find_proximal stores in it would leak between calls.
    if debug is None:
        debug = {}

    # Workaround: sort the raw tensor (sorting a Variable was not supported).
    # Indexing the Variable with the permutation keeps autograd tracking.
    _, perm = torch.sort(margins.data, dim=0, descending=True)
    margins_sorted = margins[perm]
    grad = gamma_fast(label, perm)
    loss = torch.dot(F.relu(margins_sorted), Variable(grad))

    if prox is False:
        return loss

    xp, gam = find_proximal(margins_sorted.data, grad, prox,
                            max_steps=max_steps, eps=1e-6, debug=debug)
    # Override the gradient reaching margins_sorted with (margins_sorted - xp).
    # The incoming gradient is deliberately ignored.
    hook = margins_sorted.register_hook(
        lambda _incoming_grad: Variable(margins_sorted.data - xp))
    return loss, hook, gam
# Comment list (web-page scrape residue)
# Table of contents (web-page scrape residue)