def param_mse(name, target):
    """Return the summed squared error between ``target`` and a Pyro param.

    Despite the name, no averaging is performed: the result is
    ``sum((target - pyro.param(name)) ** 2)`` reduced to a Python scalar.

    Args:
        name: Name passed to ``pyro.param`` to look up the parameter
            (presumably already registered in the param store — confirm
            at call sites, since no initial value is supplied here).
        target: Tensor the parameter is compared against; assumed to be
            broadcast-compatible with the parameter's value.

    Returns:
        The sum of squared differences as a Python number.
    """
    squared_error = torch.sum(torch.pow(target - pyro.param(name), 2.0))
    # ``torch.sum`` yields a 0-dim tensor on modern PyTorch, so the old
    # ``.numpy()[0]`` indexing raises IndexError; ``.item()`` handles both
    # 0-dim and single-element tensors and copies off any device.
    # ``.detach()`` replaces the deprecated ``.data`` access.
    return squared_error.detach().item()