def updateGradInput(self, input, gradOutput):
    """Compute the gradients of a matrix-vector product w.r.t. both inputs.

    The formulas below are the gradients of ``M @ v`` (or ``M.t() @ v`` when
    ``self.trans`` is true). The forward pass is not part of this block, so
    that is inferred from the gradient expressions themselves.

    Args:
        input: A pair ``(M, v)``. Either ``M`` is 2-D and ``v`` is 1-D, or
            (batched) ``M`` is 3-D and ``v`` is 2-D with the batch on dim 0.
        gradOutput: Gradient w.r.t. the module output; 1-D for the
            unbatched case, 2-D for the batched case.

    Returns:
        ``self.gradInput``, whose element 0 is the gradient w.r.t. ``M`` and
        element 1 is the gradient w.r.t. ``v``.

    Raises:
        AssertionError: If the dimensionalities of ``gradOutput``, ``M`` and
            ``v`` do not form one of the two supported combinations.
    """
    M, v = input
    grad_M, grad_v = self.gradInput[0], self.gradInput[1]
    grad_M.resize_as_(M)
    grad_v.resize_as_(v)

    assert gradOutput.ndimension() in (1, 2)

    if gradOutput.ndimension() == 2:
        # Batched case: one matrix / vector pair per index along dim 0.
        assert M.ndimension() == 3
        assert v.ndimension() == 2
        bdim = M.size(0)
        odim = M.size(1)
        idim = M.size(2)

        # The destination must be passed via ``out=``; bmm itself only takes
        # the two operands positionally. ``grad_v.view(...)`` shares storage
        # with ``grad_v``, so writing into the view fills the buffer.
        if self.trans:
            torch.bmm(v.view(bdim, odim, 1),
                      gradOutput.view(bdim, 1, idim),
                      out=grad_M)
            torch.bmm(M,
                      gradOutput.view(bdim, idim, 1),
                      out=grad_v.view(bdim, odim, 1))
        else:
            torch.bmm(gradOutput.view(bdim, odim, 1),
                      v.view(bdim, 1, idim),
                      out=grad_M)
            torch.bmm(M.transpose(1, 2),
                      gradOutput.view(bdim, odim, 1),
                      out=grad_v.view(bdim, idim, 1))
    else:
        # Unbatched case: a single matrix and a single vector.
        assert M.ndimension() == 2
        assert v.ndimension() == 1

        # Gradient w.r.t. v is a matrix-vector product. ``*`` would multiply
        # element-wise (with broadcasting) and yield a 2-D tensor, so use
        # torch.mv. The result rebinds gradInput[1], as the original
        # assignment did.
        if self.trans:
            torch.ger(v, gradOutput, out=grad_M)
            self.gradInput[1] = torch.mv(M, gradOutput)
        else:
            torch.ger(gradOutput, v, out=grad_M)
            self.gradInput[1] = torch.mv(M.t(), gradOutput)

    return self.gradInput
# Comment list
# Article table of contents