def sample_discrete_actions(batch_probs):
    """Draw one action index per row from a batch of action probabilities.

    Sampling uses the Gumbel-max trick: Gumbel noise is added to the
    log-probabilities and the index of the largest perturbed value along
    axis 1 is taken as the sampled action.

    Args:
        batch_probs (ndarray): batch of action probabilities BxA
    Returns:
        ndarray consisting of sampled action indices, cast to int32
    """
    # Use whichever array module chainer reports for the input so the
    # computation stays in the same array library as batch_probs.
    xp = chainer.cuda.get_array_module(batch_probs)

    log_probs = xp.log(batch_probs)
    gumbel_noise = xp.random.gumbel(size=batch_probs.shape)
    perturbed = log_probs + gumbel_noise

    action_indices = xp.argmax(perturbed, axis=1)
    # copy=False avoids a redundant copy when the indices are already int32.
    return action_indices.astype(np.int32, copy=False)
# Web-page navigation residue ("comment list", "table of contents"); not part of the code.