def sample_from_probs2(probs, out=None):
    """Sample from multiple vectors of non-normalized probabilities.

    Each row of ``probs`` is treated as an independent categorical
    distribution whose weights need not sum to one. One index is drawn per
    row by inverse-CDF sampling with ``np.random.rand``.

    Args:
        probs: An [N, M]-shaped array-like of non-negative, non-normalized
            probabilities.
        out: An optional [N]-shaped integer array that receives the result.
            It is forwarded to ``ndarray.argmax``.

    Returns:
        An [N]-shaped numpy array of integers in range(M). A row whose
        weights sum to zero yields index 0.

    Raises:
        ValueError: If ``probs`` is not 2-dimensional or has no columns.
    """
    # Adapted from https://stackoverflow.com/questions/40474436
    probs = np.asarray(probs)
    # Explicit checks instead of `assert`, which is stripped under -O.
    if probs.ndim != 2:
        raise ValueError(
            f"probs must be 2-dimensional, got shape {probs.shape}")
    if probs.shape[1] == 0:
        raise ValueError("probs must have at least one column")
    cdf = probs.cumsum(axis=1)
    # Scale uniform draws by each row's total mass rather than normalizing
    # the whole [N, M] array; cdf[:, -1:] keeps the [N, 1] shape for
    # broadcasting against cdf.
    u = np.random.rand(probs.shape[0], 1) * cdf[:, -1:]
    # argmax over a boolean array returns the first True, i.e. the first
    # column whose cumulative mass exceeds the draw.
    return (u < cdf).argmax(axis=1, out=out)
评论列表
文章目录