def sample_categorical(pmf):
    """Sample one category index per row of a categorical distribution.

    Uses inverse-CDF sampling: one uniform draw in [0, 1) per row (taken from
    NumPy's global RNG via ``np.random.rand``) is located in the row's
    cumulative distribution.

    Args:
        pmf: Probability mass function, e.g. the output of a softmax over
            categories. Array-like of shape [batch_size, num_categories]
            whose rows sum to 1. A 1-D input of shape [num_categories] is
            treated as a batch of one.

    Returns:
        idxs: Integer array of shape [batch_size, 1] holding the sampled
            category index for each row.

    Raises:
        ValueError: if ``pmf`` has no categories (last dimension of size 0).
    """
    pmf = np.asarray(pmf)
    if pmf.ndim == 1:
        pmf = np.expand_dims(pmf, 0)
    batch_size, num_categories = pmf.shape
    if num_categories == 0:
        raise ValueError("pmf must have at least one category")
    cdf = np.cumsum(pmf, axis=1)
    # One uniform draw in [0, 1) per row.
    rand_vals = np.random.rand(batch_size)
    # Category j is chosen when cdf[j-1] <= u < cdf[j]. Counting the cdf
    # entries <= u is a vectorised searchsorted(side="right") for each row.
    # Unlike side="left", it never picks a leading zero-probability category
    # when u == 0.
    idxs = np.sum(cdf <= rand_vals[:, None], axis=1)
    # Rounding can leave cdf[-1] slightly below 1, so u may exceed every
    # entry. Clamp rather than return the out-of-range index num_categories.
    idxs = np.minimum(idxs, num_categories - 1)
    return idxs.reshape(batch_size, 1)
评论列表
文章目录