rand.py source code

python

Project: CSB  Author: csb-toolbox
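def truncated_normal(shape=None, mu=0., sigma=1., x_min=None, x_max=None):
    """
    Generates random variates from a lower- and/or upper-bounded normal distribution

    @param shape: shape of the random sample
    @param mu:    location parameter
    @param sigma: width of the distribution (sigma >= 0.)
    @param x_min: lower bound of variate
    @param x_max: upper bound of variate
    @return: random variates of the bounded normal distribution
    """
    from scipy.special import erf, erfinv
    from numpy.random import standard_normal
    from numpy import inf, sqrt

    # no bounds given: plain (untruncated) normal distribution
    if x_min is None and x_max is None:
        return standard_normal(shape) * sigma + mu
    # a single missing bound is treated as infinite
    elif x_min is None:
        x_min = -inf
    elif x_max is None:
        x_max = inf

    # clamp infinite bounds to large finite values
    x_min = max(-1e300, x_min)
    x_max = min(+1e300, x_max)
    # the tiny offset keeps the scale non-zero when sigma == 0
    var = sigma ** 2 + 1e-300
    scale = sqrt(2 * var)

    # erf((x - mu) / (sigma * sqrt(2))) maps the bounds into [-1, 1]
    a = erf((x_min - mu) / scale)
    b = erf((x_max - mu) / scale)

    # probability_transform is a helper defined elsewhere in rand.py (not shown here);
    # judging from this call, it draws values uniformly from [a, b] and maps them through erfinv
    return probability_transform(shape, erfinv, a, b) * scale + mu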
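The function relies on the probability (inverse-CDF) transform. For X ~ N(mu, sigma^2), the CDF is Phi(x) = 1/2 * [1 + erf((x - mu) / (sigma * sqrt(2)))]. If U is uniform on [Phi(x_min), Phi(x_max)], then Phi^-1(U) follows the normal distribution restricted to [x_min, x_max]. Because Phi is an affine function of erf, the snippet can equivalently draw erf values uniformly from [a, b], apply erfinv, then multiply by sigma * sqrt(2) and add mu. The sketch below is a self-contained illustration of the same idea, not the CSB implementation. It swaps SciPy's erf/erfinv for the standard library's statistics.NormalDist.cdf and inv_cdf, and it returns plain lists instead of NumPy arrays. Its probability_transform helper is hypothetical; the signature is only inferred from the call in the snippet. Unlike the original, it requires sigma > 0.

```python
import random
from statistics import NormalDist

# NormalDist.inv_cdf rejects p <= 0 and p >= 1, so uniforms are clamped to this range.
_P_LO = 1e-300
_P_HI = 1.0 - 2.0 ** -53  # largest double below 1.0


def probability_transform(n, inv_cdf, p_min=0.0, p_max=1.0, rng=random):
    """Hypothetical helper: map n uniform draws on [p_min, p_max) through inv_cdf."""
    samples = []
    for _ in range(n):
        p = p_min + rng.random() * (p_max - p_min)
        # keep p strictly inside (0, 1) so inv_cdf never raises
        p = min(max(p, _P_LO), _P_HI)
        samples.append(inv_cdf(p))
    return samples


def truncated_normal(n, mu=0.0, sigma=1.0, x_min=None, x_max=None, rng=random):
    """Sample n variates of N(mu, sigma**2) restricted to [x_min, x_max]."""
    if sigma <= 0:
        raise ValueError("sigma must be positive")
    # a missing bound is treated as infinite, as in the original function
    lo = float("-inf") if x_min is None else x_min
    hi = float("inf") if x_max is None else x_max
    if lo >= hi:
        raise ValueError("x_min must be smaller than x_max")
    dist = NormalDist(mu, sigma)
    # CDF values of the bounds; cdf(-inf) == 0.0 and cdf(inf) == 1.0
    p_min, p_max = dist.cdf(lo), dist.cdf(hi)
    if p_min >= p_max:
        raise ValueError("bounds lie too far in the tail for double precision")
    raw = probability_transform(n, dist.inv_cdf, p_min, p_max, rng)
    # clip tiny floating-point round-off excursions back into the interval
    return [min(max(x, lo), hi) for x in raw]
```

For example, truncated_normal(5, mu=1.0, sigma=2.0, x_min=0.0, x_max=3.0) returns five floats in [0, 3]. Like the erf-based original, this approach loses accuracy when both bounds lie deep in the same tail, where the CDF values round to 0.0 or 1.0. In that case the sketch raises an error rather than silently returning wrong samples.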