def updateGradInput(self, input, gradOutput):
    M, v = input
    # Gradient buffers are reused across calls; make them match the input shapes.
    self.gradInput[0].resize_as_(M)
    self.gradInput[1].resize_as_(v)
    gradOutput = gradOutput.contiguous()

    assert gradOutput.ndimension() == 1 or gradOutput.ndimension() == 2

    if gradOutput.ndimension() == 2:
        # Batched case: M has shape (batch, odim, idim); v and gradOutput
        # are 2D, with their layout depending on self.trans.
        assert M.ndimension() == 3
        assert v.ndimension() == 2
        bdim = M.size(0)
        odim = M.size(1)
        idim = M.size(2)

        if self.trans:
            # Per batch: output = M^T @ v, so dL/dM is the outer product of v
            # and gradOutput, and dL/dv = M @ gradOutput.
            torch.bmm(v.view(bdim, odim, 1), gradOutput.view(bdim, 1, idim), out=self.gradInput[0])
            torch.bmm(M, gradOutput.view(bdim, idim, 1), out=self.gradInput[1].view(bdim, odim, 1))
        else:
            # Per batch: output = M @ v, so dL/dM is the outer product of
            # gradOutput and v, and dL/dv = M^T @ gradOutput.
            torch.bmm(gradOutput.view(bdim, odim, 1), v.view(bdim, 1, idim), out=self.gradInput[0])
            torch.bmm(M.transpose(1, 2), gradOutput.view(bdim, odim, 1), out=self.gradInput[1].view(bdim, idim, 1))
    else:
        # Single-sample case: M is a matrix, v is a vector.
        assert M.ndimension() == 2
        assert v.ndimension() == 1
        if self.trans:
            # output = M^T @ v: dL/dM = outer(v, gradOutput), dL/dv = M @ gradOutput.
            torch.ger(v, gradOutput, out=self.gradInput[0])
            torch.mv(M, gradOutput, out=self.gradInput[1])
        else:
            # output = M @ v: dL/dM = outer(gradOutput, v), dL/dv = M^T @ gradOutput.
            torch.ger(gradOutput, v, out=self.gradInput[0])
            torch.mv(M.t(), gradOutput, out=self.gradInput[1])

    return self.gradInput
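
The single-sample branch above spells out the standard product-rule gradients for output = M @ v (or M^T @ v when trans is set). Below is a minimal, self-contained sketch that checks the same formulas against autograd; the shapes, variable names, and use of plain torch functions here are illustrative assumptions, not part of the module itself.

import torch

# Illustrative check of the non-batched, non-transposed gradient formulas:
# output = M @ v, dL/dM = outer(gradOutput, v), dL/dv = M^T @ gradOutput.
odim, idim = 4, 3
M = torch.randn(odim, idim, requires_grad=True)
v = torch.randn(idim, requires_grad=True)

output = torch.mv(M, v)
gradOutput = torch.randn(odim)
output.backward(gradOutput)

# Manual gradients, mirroring the else-branch of updateGradInput above.
gradM = torch.ger(gradOutput, v.detach())      # outer product -> shape (odim, idim)
gradv = torch.mv(M.detach().t(), gradOutput)   # matrix-vector product -> shape (idim,)

print(torch.allclose(gradM, M.grad))  # True
print(torch.allclose(gradv, v.grad))  # True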