def _sample_discrete_actions(batch_probs):
"""Sample a batch of actions from a batch of action probabilities.
Args:
batch_probs (ndarray): batch of action probabilities BxA
Returns:
List consisting of sampled actions
"""
action_indices = []
# Subtract a tiny value from probabilities in order to avoid
# "ValueError: sum(pvals[:-1]) > 1.0" in numpy.multinomial
batch_probs = batch_probs - np.finfo(np.float32).epsneg
for i in range(batch_probs.shape[0]):
histogram = np.random.multinomial(1, batch_probs[i])
action_indices.append(int(np.nonzero(histogram)[0]))
return action_indices
# (web-page residue, not code) Comment list
# (web-page residue, not code) Article table of contents