MV.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def updateGradInput(self, input, gradOutput):
        M, v = input
        self.gradInput[0].resize_as_(M)
        self.gradInput[1].resize_as_(v)

        assert gradOutput.ndimension() == 1 or gradOutput.ndimension() == 2

        if gradOutput.ndimension() == 2:
            assert M.ndimension() == 3
            assert v.ndimension() == 2
            bdim = M.size(0)
            odim = M.size(1)
            idim = M.size(2)

            if self.trans:
                torch.bmm(self.gradInput[0], v.view(bdim, odim, 1), gradOutput.view(bdim, 1, idim))
                torch.bmm(self.gradInput[1].view(bdim, odim, 1), M, gradOutput.view(bdim, idim, 1))
            else:
                torch.bmm(self.gradInput[0], gradOutput.view(bdim, odim, 1), v.view(bdim, 1, idim))
                torch.bmm(self.gradInput[1].view(bdim, idim, 1), M.transpose(1, 2), gradOutput.view(bdim, odim, 1))
        else:
            assert M.ndimension() == 2
            assert v.ndimension() == 1

            if self.trans:
                torch.ger(self.gradInput[0], v, gradOutput)
                self.gradInput[1] = M * gradOutput
            else:
                torch.ger(self.gradInput[0], gradOutput, v)
                self.gradInput[1] = M.t() * gradOutput

        return self.gradInput
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号