init.py source code

python
Project: pytorch  Author: tylergenter
import torch
from torch.autograd import Variable


def orthogonal(tensor, gain=1):
    """Fills the input Tensor or Variable with a (semi) orthogonal matrix, as described in "Exact solutions to the
    nonlinear dynamics of learning in deep linear neural networks" - Saxe, A. et al. (2013). The input tensor must have
    at least 2 dimensions, and for tensors with more than 2 dimensions the trailing dimensions are flattened.

    Args:
        tensor: an n-dimensional torch.Tensor or autograd.Variable, where n >= 2
        gain: optional scaling factor

    Examples:
        >>> w = torch.Tensor(3, 5)
        >>> nn.init.orthogonal(w)
    """
    if isinstance(tensor, Variable):
        orthogonal(tensor.data, gain=gain)
        return tensor

    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    rows = tensor.size(0)
    cols = tensor[0].numel()
    flattened = torch.Tensor(rows, cols).normal_(0, 1)
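    # Compute the QR factorization of the random Gaussian matrix
    q, r = torch.qr(flattened)
    # Multiply each column of Q by the sign of the matching diagonal entry of R,
    # which makes Q uniformly distributed; see https://arxiv.org/pdf/math-ph/0609050.pdf
    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph.expand_as(q)
    # Pad Q with zero columns when rows < cols, so that it matches the (rows, cols) shape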
    if rows < cols:
        padding = torch.zeros(rows, cols - rows)
        if q.is_cuda:
            q = torch.cat([q, padding.cuda()], 1)
        else:
            q = torch.cat([q, padding], 1)

    tensor.view_as(q).copy_(q)
    tensor.mul_(gain)
    return tensor
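
The steps in the listing above (Gaussian sample, QR factorization, sign correction, zero padding, gain) can be checked numerically. The following is a minimal sketch that mirrors them in NumPy rather than calling the PyTorch function, so it does not depend on the old `Variable` API. The name `orthogonal_like_listing` and its `seed` parameter are hypothetical, introduced only for illustration. The sketch takes `rows` and `cols` directly instead of flattening a tensor's trailing dimensions.

```python
import numpy as np


def orthogonal_like_listing(rows, cols, gain=1.0, seed=0):
    """NumPy sketch of the listing's steps (hypothetical helper, not a PyTorch API)."""
    rng = np.random.default_rng(seed)
    flattened = rng.standard_normal((rows, cols))
    # Reduced QR: q is (rows, k) and r is (k, cols), with k = min(rows, cols).
    q, r = np.linalg.qr(flattened)
    # Flip each column of q by the sign of the matching diagonal entry of r.
    q = q * np.sign(np.diag(r))
    # When rows < cols, q is only (rows, rows), so pad it with zero columns.
    if rows < cols:
        q = np.concatenate([q, np.zeros((rows, cols - rows))], axis=1)
    return gain * q


tall = orthogonal_like_listing(5, 3)  # rows > cols: orthonormal columns
wide = orthogonal_like_listing(3, 5)  # rows < cols: orthonormal rows
print(np.allclose(tall.T @ tall, np.eye(3)))
print(np.allclose(wide @ wide.T, np.eye(3)))
print(np.allclose(wide[:, 3:], 0.0))
```

In the `rows < cols` case the result still has orthonormal rows, but its trailing `cols - rows` columns are exactly zero, which follows directly from the padding step.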