def normalize(data, p=2, dim=1, eps=1e-12):
    """Scale ``data`` so each slice along ``dim`` has unit L_p norm.

    Computes ``data / max(||data||_p, eps)``, with the norm taken along ``dim``.

    Args:
        data: Input tensor.
        p: Norm order passed to ``torch.norm``.
        dim: Dimension along which the norm is computed.
        eps: Lower bound applied to the norm before dividing. It avoids
            division by zero, so an all-zero slice maps to zeros.

    Returns:
        A tensor with the same shape as ``data``.
    """
    # keepdim=True keeps the reduced axis as size 1. The norm then broadcasts
    # along ``dim`` itself. Without it, the reduced shape aligns with the
    # trailing dims instead, which either errors or divides by the wrong norm
    # (e.g. for square matrices).
    norms = torch.norm(data, p, dim, keepdim=True).clamp(min=eps)
    return data / norms