def updateGradInput(self, input, gradOutput):
    """Compute the gradient of the cosine similarity w.r.t. both inputs.

    Relies on per-row buffers cached by ``updateOutput``
    (NOTE(review): inferred from the gradient formula below — confirm
    against ``updateOutput``, which is not visible in this chunk):
      self.w1  = <v1, v2>                 shape (n, 1)
      self.w22 = 1 / ||v1||^2             shape (n, 1)
      self.w32 = 1 / ||v2||^2             shape (n, 1)
      self.w   = 1 / (||v1|| * ||v2||)    shape (n, 1)

    Args:
        input: pair ``(v1, v2)`` of 2-D tensors of identical shape (n, d).
        gradOutput: tensor of n upstream gradients (viewed as (n, 1)).

    Returns:
        ``self.gradInput``: list ``[d loss/d v1, d loss/d v2]``.
    """
    v1 = input[0]
    v2 = input[1]
    v1, v2 = self._makeContiguous(v1, v2)

    # Ensure gradInput is a list of exactly two allocated tensors.
    # The previous `self.gradInput[0] or v1.new()` raised for multi-element
    # tensors (ambiguous truth value) and could IndexError on short lists;
    # use explicit `is None` / length handling instead.
    if len(self.gradInput) != 2:
        grads = list(self.gradInput[:2])
        while len(grads) < 2:
            grads.append(v1.new())
        self.gradInput = [g if g is not None else v1.new() for g in grads]

    gw1 = self.gradInput[0]
    gw2 = self.gradInput[1]
    gw1.resize_as_(v1).copy_(v2)
    gw2.resize_as_(v1).copy_(v1)

    # d cos / d v1 = (v2 - <v1,v2>/||v1||^2 * v1) / (||v1|| ||v2||).
    # torch.mul no longer accepts the output tensor as the first positional
    # argument, and addcmul_'s value-first overload is gone; use keywords.
    torch.mul(self.w1, self.w22, out=self.buffer)
    gw1.addcmul_(self.buffer.expand_as(v1), v1, value=-1)
    gw1.mul_(self.w.expand_as(v1))

    # d cos / d v2 = (v1 - <v1,v2>/||v2||^2 * v2) / (||v1|| ||v2||).
    torch.mul(self.w1, self.w32, out=self.buffer)
    gw2.addcmul_(self.buffer.expand_as(v1), v2, value=-1)
    gw2.mul_(self.w.expand_as(v1))

    # Chain rule: scale both gradients by the upstream gradient,
    # broadcast across the feature dimension.
    go = gradOutput.view(-1, 1).expand_as(v1)
    gw1.mul_(go)
    gw2.mul_(go)

    return self.gradInput
# (scraped-page residue, not code: "评论列表" = comment list, "文章目录" = article table of contents)