util.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:treecat 作者: posterior 项目源码 文件源码
def sample_from_probs2(probs, out=None):
    """Sample from multiple vectors of non-normalized probabilities.

    Args:
        probs: An [N, M]-shaped numpy array of non-normalized probabilities.
        out: An optional destination for the result.

    Returns:
        An [N]-shaped numpy array of integers in range(M).
    """
    # Adapted from https://stackoverflow.com/questions/40474436
    assert len(probs.shape) == 2
    cdf = probs.cumsum(axis=1)
    u = np.random.rand(probs.shape[0], 1) * cdf[:, -1, np.newaxis]
    return (u < cdf).argmax(axis=1, out=out)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号