def updateGradInput(self, input, y):
    v1 = input[0]
    v2 = input[1]
    gw1 = self.gradInput[0]
    gw2 = self.gradInput[1]
    # The gradient of cos(v1, v2) w.r.t. v1 starts from v2, and vice versa.
    gw1.resize_as_(v1).copy_(v2)
    gw2.resize_as_(v1).copy_(v1)
    # Buffers cached by updateOutput: w1 = <v1, v2>, w22 = 1 / ||v1||^2,
    # w32 = 1 / ||v2||^2, w = 1 / (||v1|| * ||v2||).
    # gw1 = (v2 - <v1, v2> / ||v1||^2 * v1) / (||v1|| * ||v2||)
    torch.mul(self.w1, self.w22, out=self.buffer)
    gw1.addcmul_(self.buffer.expand_as(v1), v1, value=-1)
    gw1.mul_(self.w.expand_as(v1))
    # gw2 = (v1 - <v1, v2> / ||v2||^2 * v2) / (||v1|| * ||v2||)
    torch.mul(self.w1, self.w32, out=self.buffer)
    gw2.addcmul_(self.buffer.expand_as(v1), v2, value=-1)
    gw2.mul_(self.w.expand_as(v1))
    # Zero the gradient for samples whose per-sample loss is already 0
    # (the hinge for y == -1 pairs is inactive).
    # self._idx = self._outputs <= 0
    torch.le(self._outputs, 0, out=self._idx)
    self._idx = self._idx.view(-1, 1).expand(gw1.size())
    gw1[self._idx] = 0
    gw2[self._idx] = 0
    # For y == 1 pairs the loss is 1 - cos, so the sign of d cos / d v flips.
    torch.eq(y, 1, out=self._idx)
    self._idx = self._idx.view(-1, 1).expand(gw2.size())
    gw1[self._idx] = gw1[self._idx].mul_(-1)
    gw2[self._idx] = gw2[self._idx].mul_(-1)
    # Average over the batch when sizeAverage is set.
    if self.sizeAverage:
        gw1.div_(y.size(0))
        gw2.div_(y.size(0))
    return self.gradInput
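
For comparison, current PyTorch exposes the same criterion as nn.CosineEmbeddingLoss and lets autograd derive this gradient. A minimal sketch follows; the shapes, margin, and variable names are illustrative and not part of the legacy module:

import torch
import torch.nn as nn

loss_fn = nn.CosineEmbeddingLoss(margin=0.0)   # margin plays the same role as self.margin
v1 = torch.randn(4, 8, requires_grad=True)     # first batch of vectors
v2 = torch.randn(4, 8, requires_grad=True)     # second batch of vectors
y = torch.tensor([1, -1, 1, -1])               # +1: similar pair, -1: dissimilar pair

loss = loss_fn(v1, v2, y)                      # reduction='mean' corresponds to sizeAverage=True
loss.backward()                                # v1.grad / v2.grad correspond to gw1 / gw2 above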