def erfinv(x: T.FloatTensor) -> T.FloatTensor:
    """
    Elementwise inverse error function of a tensor (closed-form approximation).

    Uses Winitzki's approximation

        erfinv(x) ~= sign(x) * sqrt(b + sqrt(b**2 - log(1 - x**2) / a)),
        b = -2 / (pi * a) - log(1 - x**2) / 2,

    with a = 8 (pi - 3) / (3 pi (4 - pi)) ~= 0.140.  The result is an
    approximation, not the exact inverse error function.

    Args:
        x: A tensor whose entries are expected to lie in [-1, 1].

    Returns:
        tensor: Elementwise approximate inverse error function of x.
            Entries equal to +1 / -1 map to +inf / -inf; entries with
            |x| > 1 produce NaN.
    """
    a = 8.0 / (3.0 * pi) * (pi - 3.0) / (4.0 - pi)
    # log1p keeps log(1 - x^2) accurate when x^2 is below machine epsilon;
    # the naive 1 - x^2 would round to exactly 1 and the log to 0.
    log_term = torch.log1p(-x * x)
    b = -2.0 / (pi * a) - log_term / 2
    d = -log_term / a  # non-negative for |x| <= 1
    s = torch.sqrt(b * b + d)
    # For small |x|, b is negative and b + s cancels catastrophically.
    # d / (s - b) is algebraically identical there and has no cancellation.
    # For |x| very close to 1, b becomes non-negative and b + s is already stable.
    inner = torch.where(b < 0, d / (s - b), b + s)
    return torch.sign(x) * torch.sqrt(inner)
评论列表
文章目录