pytorch_model.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:char-rnn-text-generation 作者: yxtay 项目源码 文件源码
def sample_from_probs(probs, top_n=10):
    """
    truncated weighted random choice.
    """
    _, indices = torch.sort(probs)
    # set probabilities after top_n to 0
    probs[indices.data[:-top_n]] = 0
    sampled_index = torch.multinomial(probs, 1)
    return sampled_index
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号