def truncated_normal(shape=None, mu=0., sigma=1., x_min=None, x_max=None):
    """
    Generates random variates from a lower- and upper-bounded normal distribution

    Sampling uses the inverse-transform method: the standardized bounds are
    mapped through the error function (or, for intervals lying entirely in
    one tail, the complementary error function), `probability_transform`
    draws between the mapped bounds, and the matching inverse function maps
    the draws back to the variate scale.

    @param shape: shape of the random sample
    @param mu: location parameter
    @param sigma: width of the distribution (sigma >= 0.)
    @param x_min: lower bound of variate (None: unbounded below)
    @param x_max: upper bound of variate (None: unbounded above)
    @return: random variates of the normal distribution truncated to
        [x_min, x_max]
    """
    from scipy.special import erf, erfinv, erfc, erfcinv
    from numpy.random import standard_normal
    from numpy import inf, sqrt

    if x_min is None and x_max is None:
        return standard_normal(shape) * sigma + mu
    if x_min is None:
        x_min = -inf
    elif x_max is None:
        x_max = inf

    # Replace infinite bounds by huge finite ones (behavior kept from the
    # original implementation).
    x_min = max(-1e300, x_min)
    x_max = min(+1e300, x_max)

    # erf(z) with z = (x - mu) / (sqrt(2) * sigma) equals 2 * Phi((x - mu) / sigma) - 1,
    # so bounds are standardized by sqrt(2) * sigma. The tiny offset keeps the
    # division finite when sigma == 0. `sigma` itself is not rebound.
    scale = sqrt(2 * (sigma ** 2 + 1e-300))
    z_min = (x_min - mu) / scale
    z_max = (x_max - mu) / scale

    # NOTE(review): the branches below assume probability_transform(shape, f, a, b)
    # returns f applied to uniform draws from [a, b], which is what the
    # erf/erfinv pairing in the interior case requires -- confirm against its
    # definition.
    if z_min > 0:
        # Whole interval in the upper tail. erf(z) rounds to 1.0 for z >~ 6,
        # which would collapse [a, b] and make erfinv return inf. erfc keeps
        # full relative precision there. erfc is decreasing, so erfc(z_max)
        # is the lower end of the sampling interval.
        w = probability_transform(shape, erfcinv, erfc(z_max), erfc(z_min))
        return w * scale + mu
    if z_max < 0:
        # Whole interval in the lower tail: mirror it into the upper tail
        # around mu, sample there, and reflect back.
        w = probability_transform(shape, erfcinv, erfc(-z_min), erfc(-z_max))
        return mu - w * scale

    # Interval straddles mu: erf is well conditioned here.
    a = erf(z_min)
    b = erf(z_max)
    return probability_transform(shape, erfinv, a, b) * scale + mu
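# Scraped web-page navigation residue (not code), kept as a comment:
# "评论列表" (comment list) / "文章目录" (table of contents)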