def sample_from_probs(probs, top_n=10):
    """Sample one index from ``probs`` restricted to its ``top_n`` largest entries.

    The weights outside the ``top_n`` largest are set to zero on a private
    copy, and a single index is then drawn with ``torch.multinomial`` using
    the remaining weights. The caller's ``probs`` tensor is left unmodified.

    Args:
        probs: Tensor of non-negative weights. The truncation indexes along
            the first dimension, so this is assumed to be 1-D -- TODO confirm
            against callers.
        top_n: Number of highest-weight entries to keep. When ``top_n`` is 0
            or is at least the number of entries, nothing is zeroed and the
            draw is made over all of ``probs``.

    Returns:
        The tensor returned by ``torch.multinomial(..., 1)`` holding the
        sampled index.
    """
    # torch.sort is ascending by default, so the largest weights come last.
    _, indices = torch.sort(probs)
    # Truncate on a detached copy: zeroing ``probs`` in place would corrupt
    # the caller's distribution for any later use, and would fail outright on
    # a leaf tensor that requires grad.
    truncated = probs.detach().clone()
    # Zero everything except the last ``top_n`` positions of the sorted order.
    truncated[indices[:-top_n]] = 0
    sampled_index = torch.multinomial(truncated, 1)
    return sampled_index
评论列表
文章目录