def batch_norm_scattering(x, m, v):
    """Normalize ``x`` with externally supplied statistics: ``(x - m) / v``.

    Both statistics are expanded to the shape of ``x`` with ``expand_as``
    before the element-wise subtraction and division. An incompatible shape
    therefore raises before any arithmetic is done.

    Args:
        x: Input tensor to normalize.
        m: Tensor subtracted from ``x``. Must be expandable to ``x.shape``.
        v: Tensor the centered values are divided by. Must be expandable
            to ``x.shape``.

    Returns:
        A new tensor equal to ``(x - m) / v``, with the shape of ``x``.
        None of the inputs are modified.
    """
    # NOTE(review): no square root and no epsilon are applied to ``v``. If
    # callers pass variances rather than standard deviations, the usual
    # batch-norm formula would be (x - m) / sqrt(v + eps). Confirm with callers.
    mean_full = m.expand_as(x)
    scale_full = v.expand_as(x)
    centered = x - mean_full
    return centered / scale_full