def erfinv_approx(self, x):
    """Approximate the inverse error function elementwise.

    Evaluates Winitzki's closed-form approximation::

        erfinv(x) ~= sign(x) * sqrt(B + sqrt(B**2 - ln(1 - x**2) / a)),
        B = -2 / (pi * a) - ln(1 - x**2) / 2,

    with ``a = self.a_for_erf`` (NOTE(review): commonly 0.147 or
    8(pi-3)/(3 pi (4-pi)) ~= 0.140 -- confirm the value used by callers).

    Args:
        x: Tensor of inputs; the approximation is meaningful for
            ``-1 <= x <= 1``.

    Returns:
        Tensor of the same shape as ``x``: 0 at ``x == 0``, ``+inf`` / ``-inf``
        at ``x == 1`` / ``x == -1``, and NaN where ``|x| > 1``.
    """
    a = self.a_for_erf
    # log1p(-x^2) stays accurate for small |x|; log(1 - x*x) rounds to
    # log(1) == 0 once x*x is below machine epsilon, which zeroes the result.
    log_term = torch.log1p(-x * x)
    b = -2 / (math.pi * a) - log_term / 2
    c = -log_term / a  # >= 0 on the valid domain
    root = torch.sqrt(b * b + c)
    # For small |x|, b is negative and b + root cancels catastrophically.
    # There use the algebraically equal conjugate c / (root - b), whose
    # denominator is >= 2|b| > 0. Where b >= 0 (including the +-inf endpoints)
    # the plain sum is stable.
    inner = torch.where(b < 0, c / (root - b), b + root)
    return torch.sign(x) * torch.sqrt(inner)