def torch_eye(n, m=None, out=None):
    """
    Like `torch.eye()`, but works with cuda tensors.

    ``torch.eye`` is tried first. If it raises ``TypeError`` and the relevant
    tensor class lives in the ``torch.cuda`` module, the identity matrix is
    built in the same-named CPU tensor class and then moved to the GPU.

    :param int n: number of rows.
    :param int m: number of columns; defaults to ``n``.
    :param out: optional output tensor to write the result into.
    :returns: an ``n x m`` tensor with ones on the main diagonal and zeros
        elsewhere; when ``out`` is given, the filled ``out`` is returned.
    :raises TypeError: re-raised when the failure is not caused by the tensor
        class being a ``torch.cuda`` type.
    """
    if m is None:
        m = n
    try:
        return torch.eye(n, m, out=out)
    except TypeError:
        # Only catch errors due to torch.eye() not being available for cuda tensors.
        tensor_type = torch.Tensor if out is None else type(out)
        if tensor_type.__module__ != 'torch.cuda':
            raise
        # CPU tensor class with the same name as the cuda one,
        # e.g. torch.cuda.FloatTensor -> torch.FloatTensor.
        cpu_type = getattr(torch, tensor_type.__name__)
        cpu_out = cpu_type(n, m)
        # Rows first, then columns, so the result has the requested (n, m) shape
        # (previously swapped, which broke non-square matrices).
        cuda_out = torch.eye(n, m, out=cpu_out).cuda()
        return cuda_out if out is None else out.copy_(cuda_out)
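# 评论列表 (comment list) — residue from the web page this code was copied from
# 文章目录 (table of contents) — residue from the web page this code was copied from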