def enforce_angle(ang, xnorm, target, margin=0, linearized=False):
    """Apply a multiplicative angular margin to each sample's target-class angle.

    For every row ``i`` the entry ``ang[i, target[i]]`` is multiplied by
    ``m = margin + 1``. All other entries are left as-is. The whole tensor is
    then passed through ``psi`` and finally scaled row-wise by ``xnorm``.

    Args:
        ang: 2-D tensor indexed as ``ang[sample, class]``. It must be 2-D
            because ``torch.gather``/``scatter`` are called with a 2-D
            ``(-1, 1)`` index along dim 1. Presumably these are the angles
            between features and class weights (TODO: confirm with the caller).
        xnorm: Per-sample scale factor. It is reshaped to ``(-1, 1)`` and
            expanded across dim 1 of the ``psi`` output.
        target: Class index per sample. ``torch.gather``/``scatter`` require
            an int64 tensor.
        margin: Offset of the angle multiplier, ``m = margin + 1``. The
            default ``0`` therefore leaves the target angles unscaled.
        linearized: Passed through unchanged to ``psi``.

    Returns:
        A new tensor. ``ang`` itself is not modified, because the
        out-of-place ``Tensor.scatter`` is used rather than ``scatter_``.
    """
    # The multiplier is offset by one so that ``margin=0`` is the identity
    # (m == 1). NOTE(review): presumably this mirrors the ``margin`` convention
    # of sibling margin functions in this module. Confirm.
    m = margin + 1

    # Column index of each sample's target class, shaped (N, 1) for gather/scatter.
    idx = target.reshape(-1, 1)

    # Scale only the target-class angle of each row. All others stay unchanged.
    target_ang = torch.gather(ang, 1, idx).mul(m)
    ang = ang.scatter(1, idx, target_ang)

    ang = psi(ang, linearized)

    # ``expand_as`` rather than implicit broadcasting. It raises if xnorm
    # cannot be expanded to the shape of the psi output.
    return ang.mul(xnorm.reshape(-1, 1).expand_as(ang))
评论列表
文章目录