def erf(x: T.FloatTensor) -> T.FloatTensor:
    """
    Elementwise approximation of the error function of a tensor.

    Uses the closed-form (Winitzki-style) approximation

        erf(x) ~= sign(x) * sqrt(1 - exp(-x**2 * (4/pi + a*x**2) / (1 + a*x**2)))

    with a = 8 (pi - 3) / (3 pi (4 - pi)) ~= 0.140012. This is an
    approximation, not an exact erf; the published relative error for this
    choice of `a` is on the order of 1e-4.

    Args:
        x: A tensor.

    Returns:
        tensor: Elementwise approximate error function of x (same shape as x).

    Note:
        Under autograd the derivative at exactly x == 0 is not finite,
        because torch.sqrt has an unbounded derivative at 0.
    """
    # For |x| >= 10 the exponent below exceeds ~100, so the result already
    # rounds to exactly +/-1 in float16/32/64. Clamping therefore changes no
    # finite result; it only keeps x*x from overflowing to inf, which used
    # to produce inf/inf = nan for huge or infinite inputs.
    x_c = torch.clamp(x, min=-10.0, max=10.0)
    a = 8.0/(3.0*pi)*(pi-3.0)/(4.0-pi)
    x_sq = x_c*x_c
    exponent = x_sq*(4/pi+a*x_sq)/(1+a*x_sq)
    # -expm1(-e) equals 1 - exp(-e) without catastrophic cancellation when e is
    # tiny; the naive form returns 0 for |x| below ~1e-4 in float32.
    return torch.sign(x_c)*torch.sqrt(-torch.expm1(-exponent))
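# 评论列表 ("Comment list") -- stray web-page text, kept as a comment
# 文章目录 ("Table of contents") -- stray web-page text, kept as a comment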