def __call__(self, module):
    """Rescale ``module.weight`` so each slice along ``self.axis`` has 2-norm <= ``self.value``.

    Follows ``torch.renorm`` semantics with ``p=2``:
    - Slices whose L2 norm is already within ``self.value`` are left unchanged.
    - Larger slices are scaled down to (approximately) ``self.value``.

    ``self.axis`` and ``self.value`` are read from the instance. Presumably
    they are set by the enclosing constraint's constructor (not visible
    here; TODO confirm).

    Args:
        module: object exposing a tensor-backed ``weight`` attribute.
            Presumably an ``nn.Module`` such as ``nn.Linear`` or
            ``nn.Conv*``; verify against callers.

    Returns:
        None. ``module.weight`` is modified in place.
    """
    # Work on ``.data`` so the rescale stays invisible to autograd, as before.
    # Do it in place rather than reassigning ``weight.data`` to a new tensor:
    # reassignment swaps the parameter's storage (breaking any aliases of the
    # old storage, e.g. views or flattened parameter buffers) and costs an
    # extra allocation on every call.
    module.weight.data.renorm_(2, self.axis, self.value)