util.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def torch_eye(n, m=None, out=None):
    """
    Like `torch.eye()`, but works with cuda tensors.
    """
    if m is None:
        m = n
    try:
        return torch.eye(n, m, out=out)
    except TypeError:
        # Only catch errors due to torch.eye() not being available for cuda tensors.
        module = torch.Tensor.__module__ if out is None else type(out).__module__
        if module != 'torch.cuda':
            raise
    Tensor = getattr(torch, torch.Tensor.__name__)
    cpu_out = Tensor(n, m)
    cuda_out = torch.eye(m, n, out=cpu_out).cuda()
    return cuda_out if out is None else out.copy_(cuda_out)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号