MV.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:pytorch 作者: ezyang 项目源码 文件源码
def updateGradInput(self, input, gradOutput):
        M, v = input
        self.gradInput[0].resize_as_(M)
        self.gradInput[1].resize_as_(v)
        gradOutput = gradOutput.contiguous()

        assert gradOutput.ndimension() == 1 or gradOutput.ndimension() == 2

        if gradOutput.ndimension() == 2:
            assert M.ndimension() == 3
            assert v.ndimension() == 2
            bdim = M.size(0)
            odim = M.size(1)
            idim = M.size(2)

            if self.trans:
                torch.bmm(v.view(bdim, odim, 1), gradOutput.view(bdim, 1, idim), out=self.gradInput[0])
                torch.bmm(M, gradOutput.view(bdim, idim, 1), out=self.gradInput[1].view(bdim, odim, 1))
            else:
                torch.bmm(gradOutput.view(bdim, odim, 1), v.view(bdim, 1, idim), out=self.gradInput[0])
                torch.bmm(M.transpose(1, 2), gradOutput.view(bdim, odim, 1), out=self.gradInput[1].view(bdim, idim, 1))
        else:
            assert M.ndimension() == 2
            assert v.ndimension() == 1

            if self.trans:
                torch.ger(v, gradOutput, out=self.gradInput[0])
                self.gradInput[1] = M * gradOutput
            else:
                torch.ger(gradOutput, v, out=self.gradInput[0])
                self.gradInput[1] = M.t() * gradOutput

        return self.gradInput
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号