def weight_proj_l2norm(param, upper=None):
    """Rescale ``param`` in place by ``min(upper, 1 / ||param||_2)``.

    The L2 norm is taken over every element of ``param.data``. When
    ``1 / norm`` is below ``upper`` the tensor is divided by its norm
    (ending up with approximately unit L2 norm); otherwise it is simply
    multiplied by ``upper``.

    Args:
        param: Object exposing a ``.data`` tensor (e.g. a module
            parameter). Its ``.data`` is modified in place.
        upper: Ceiling for the scaling coefficient. When ``None`` (the
            default) the value is read from the module-level
            ``opt.wproj_upper``, which is the original behaviour.
            NOTE(review): ``opt`` is defined outside this block --
            presumably the parsed run options; confirm.

    Returns:
        None. ``param.data`` is updated in place.
    """
    if upper is None:
        # Backward-compatible default: existing call sites pass only `param`.
        upper = opt.wproj_upper

    # The small epsilon keeps the reciprocal finite for an all-zero tensor.
    l2 = torch.norm(param.data, p=2) + 1e-8
    scale = min(upper, 1.0 / l2)

    # Scaling goes through `.data`, as in the original implementation.
    param.data.mul_(scale)


# custom weights initialization called on netG and netD