fastgen.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:magenta 作者: tensorflow 项目源码 文件源码
def sample_categorical(pmf):
  """Sample from a categorical distribution.

  Args:
    pmf: Probablity mass function. Output of a softmax over categories.
      Array of shape [batch_size, number of categories]. Rows sum to 1.

  Returns:
    idxs: Array of size [batch_size, 1]. Integer of category sampled.
  """
  if pmf.ndim == 1:
    pmf = np.expand_dims(pmf, 0)
  batch_size = pmf.shape[0]
  cdf = np.cumsum(pmf, axis=1)
  rand_vals = np.random.rand(batch_size)
  idxs = np.zeros([batch_size, 1])
  for i in range(batch_size):
    idxs[i] = cdf[i].searchsorted(rand_vals[i])
  return idxs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号