init.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:pytorch 作者: pytorch 项目源码 文件源码
def eye(tensor):
    """Fills the 2-dimensional input Tensor or Variable with the identity
    matrix. Preserves the identity of the inputs in Linear layers, where as
    many inputs are preserved as possible.

    Args:
        tensor: a 2-dimensional torch.Tensor or autograd.Variable

    Examples:
        >>> w = torch.Tensor(3, 5)
        >>> nn.init.eye(w)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    if isinstance(tensor, Variable):
        eye(tensor.data)
        return tensor

    return tensor.copy_(torch.eye(tensor.size(0), tensor.size(1)))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号