def sparse_eye(size, dtype=None, device=None):
    """Return the ``size x size`` identity matrix as a sparse COO tensor.

    Args:
        size: Number of rows (and columns) of the square matrix; must be >= 0.
        dtype: Dtype of the stored diagonal values. Defaults to
            ``torch.float32`` (what the legacy ``torch.sparse.FloatTensor``
            produced), independent of ``torch.get_default_dtype()``.
        device: Device to allocate on. ``None`` uses the current default
            device.

    Returns:
        A sparse COO tensor of shape ``(size, size)`` with ones on the
        diagonal and no other stored entries.

    Raises:
        ValueError: If ``size`` is negative.
    """
    if size < 0:
        raise ValueError(f"size must be non-negative, got {size}")
    if dtype is None:
        # Pin float32 explicitly so the result does not change with the
        # global default dtype.
        dtype = torch.float32
    # Diagonal entry i lives at (row=i, col=i): both index rows are 0..size-1.
    diag = torch.arange(size, dtype=torch.long, device=device)
    indices = torch.stack([diag, diag])
    values = torch.ones(size, dtype=dtype, device=device)
    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor.
    return torch.sparse_coo_tensor(indices, values, (size, size))
评论列表
文章目录