def find_proximal(x0, gam, lam, eps=1e-6, max_steps=20, debug=None):
    """Iteratively move ``x0`` along ``-gam`` until a stopping step is found.

    Args:
        x0: sorted margins data (1-D tensor). It is cloned, never modified.
        gam: initial ``gamma_fast(target, perm)`` direction. It is handed to
            ``project`` and returned; presumably ``project`` updates it in
            place -- TODO confirm against ``project``.
        lam: regularisation parameter; enters the stopping step as ``1 / lam``.
        eps: tolerance used both for "entry is active" (``x >= eps``) and for
            "two neighbouring entries are equal" (difference ``< eps``).
        max_steps: upper bound on the number of loop iterations.
        debug: optional dict. Only keys that are already present are written:
            ``"path"`` (a list; ``x.numpy()`` is appended after every
            iteration), ``"step"`` (number of iterations performed) and
            ``"finished"`` (whether the stopping condition was reached).

    Returns:
        Tuple ``(x, gam)``: the final point and the direction tensor.
    """
    # Avoid a shared mutable default argument.
    if debug is None:
        debug = {}

    x = x0.clone()
    act = (x >= eps).nonzero()
    finished = False
    # ``nonzero()`` yields a (0, 1)-shaped tensor when nothing is active; its
    # ``size()`` is a non-empty (truthy) tuple, so emptiness has to be tested
    # with ``numel()``.
    if act.numel() == 0:
        finished = True
    else:
        # Plain int so that it can be used as a key of ``members`` (tensors
        # hash by identity and would never match the int keys).
        active = int(act[-1, 0])
        # members[i] is the set of indices tied with i; tied indices share
        # one and the same set object.
        members = {i: {i} for i in range(active + 1)}
        if active > 0:
            equal = (x[:active] - x[1:active + 1]) < eps
            for i, tied in enumerate(equal.tolist()):
                if tied:
                    members[i].update(members[i + 1])
                    members[i + 1] = members[i]
        # NOTE(review): the original indentation was lost; this call is placed
        # at the "some entry is active" level (it may originally have been
        # nested under ``if active > 0``) -- confirm against upstream.
        project(gam, active, members)

    step = 0
    # ``active`` is only defined when ``finished`` is False; the short-circuit
    # on ``not finished`` keeps the last test from being evaluated otherwise.
    while not finished and step < max_steps and active > -1:
        step += 1
        delta, ind = compute_step_length(x, gam, active, eps)
        ind = int(ind)  # used as a dict key below

        if ind == -1:
            # Drop the whole last group from the active range.
            active = active - len(members[active])

        stop = torch.dot(x - x0, gam) / torch.dot(gam, gam) + 1. / lam
        if 0 <= stop < delta:
            # The stopping step comes before the next event: take it and end.
            delta = stop
            finished = True

        x = x - delta * gam

        if not finished:
            if ind >= 0:
                # Merge the group of ``ind`` with the group of ``ind + 1`` and
                # point every member at the merged set.
                rep = min(members[ind])
                members[rep].update(members[ind + 1])
                for m in members[ind]:
                    if m != rep:
                        members[m] = members[rep]
            # NOTE(review): placed at the ``if not finished`` level (original
            # indentation lost; it may have been nested under ``if ind >= 0``)
            # -- confirm against upstream.
            project(gam, active, members)

        if "path" in debug:
            debug["path"].append(x.numpy())

    if "step" in debug:
        debug["step"] = step
    if "finished" in debug:
        debug["finished"] = finished
    return x, gam
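# NOTE: stray web-page navigation text ("评论列表" = comment list, "文章目录" = table of contents); not part of the code.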