def eye(tensor):
    """Fill the 2-dimensional input Tensor or Variable with the identity matrix.

    Preserves the identity of the inputs in Linear layers, where as many
    inputs are preserved as possible. For a non-square matrix, ones are
    placed on the main diagonal and every other entry is zero.

    Args:
        tensor: a 2-dimensional torch.Tensor or autograd.Variable. It is
            modified in place.

    Returns:
        The same object that was passed in, now holding the identity matrix.

    Raises:
        ValueError: if ``tensor`` does not have exactly 2 dimensions.

    Examples:
        >>> w = torch.Tensor(3, 5)
        >>> nn.init.eye(w)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    # Write through ``.data`` for Variables so the fill bypasses autograd.
    # The copy is done once, without recursing: on PyTorch >= 0.4,
    # ``isinstance(x, Variable)`` is true for every Tensor (including
    # ``x.data``), so recursing on ``tensor.data`` would never terminate.
    target = tensor.data if isinstance(tensor, Variable) else tensor
    rows, cols = target.size(0), target.size(1)
    # copy_ converts dtype/device from the default CPU tensor built by torch.eye.
    target.copy_(torch.eye(rows, cols))
    return tensor
评论列表
文章目录