util.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def torch_multinomial(input, num_samples, replacement=False):
    """
    Like `torch.multinomial()` but works with cuda tensors.
    Does not support keyword argument `out`.
    """
    if input.is_cuda:
        return torch_multinomial(input.cpu(), num_samples, replacement).cuda()
    else:
        return torch.multinomial(input, num_samples, replacement)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号