def updateGradInput(self, input, gradOutput):
    """Compute the gradient of the layer output with respect to ``input``.

    The derivative implemented here is::

        dy_j   -2 * c_j * c_j * (w_j - x)   c_j * c_j * (x - w_j)
        ---- = -------------------------- = ---------------------
         dx     2 || c_j * (w_j - x) ||              y_j

    where ``y_j`` is read from ``self.output`` and ``c_j`` from
    ``self.diagCov`` (1D case) or ``self._expand3`` / ``self._repeat3``
    (2D case).  ``self._repeat`` is used as the already scaled difference
    term -- presumably filled in by ``updateOutput``; verify against that
    method.

    Args:
        input: 1D tensor, or 2D tensor whose first dimension is the batch.
        gradOutput: gradient with respect to the layer output.

    Returns:
        ``self.gradInput`` resized like ``input``, or ``None`` when
        ``self.gradInput`` is ``None``.

    Raises:
        RuntimeError: if ``input`` is neither 1D nor 2D.
    """
    if self.gradInput is None:
        return

    # Lazily create the scratch buffers on first use.
    if self._div is None:
        self._div = input.new()
    if self._output is None:
        self._output = self.output.new()
    if self._expand4 is None:
        self._expand4 = input.new()
    if self._gradOutput is None:
        self._gradOutput = input.new()

    # Unless the caller promises the forward buffers are still valid,
    # recompute them for this input.
    if not self.fastBackward:
        self.updateOutput(input)

    inputSize, outputSize = self.weight.size(0), self.weight.size(1)

    # Offset the denominator to prevent div by zero (NaN) bugs.
    self._output.resize_as_(self.output).copy_(self.output).add_(1e-7)
    # NOTE(review): _view fills self._gradOutput, but the division below
    # reads gradOutput directly (kept as in the original) -- confirm intent.
    self._view(self._gradOutput, gradOutput, gradOutput.size())
    torch.div(gradOutput, self._output, out=self._div)

    if input.dim() == 1:
        self._div.resize_(1, outputSize)
        self._expand4 = self._div.expand_as(self.weight)

        # Fixed: the original called the non-existent ``torch.type(input)``
        # and passed two operands to ``mul_``; both now mirror the 2D branch.
        if input.type() == 'torch.cuda.FloatTensor':
            self._repeat2.resize_as_(self._expand4).copy_(self._expand4)
            self._repeat2.mul_(self._repeat)
        else:
            torch.mul(self._repeat, self._expand4, out=self._repeat2)

        self._repeat2.mul_(self.diagCov)
        # Reduce over the output units, leaving one value per input feature.
        torch.sum(self._repeat2, 1, True, out=self.gradInput)
        self.gradInput.resize_as_(input)
    elif input.dim() == 2:
        batchSize = input.size(0)

        # Insert a singleton feature axis so the per-output ratio can be
        # expanded across every input feature.
        self._div.resize_(batchSize, 1, outputSize)
        self._expand4 = self._div.expand(batchSize, inputSize, outputSize)

        if input.type() == 'torch.cuda.FloatTensor':
            self._repeat2.resize_as_(self._expand4).copy_(self._expand4)
            self._repeat2.mul_(self._repeat)
            self._repeat2.mul_(self._repeat3)
        else:
            torch.mul(self._repeat, self._expand4, out=self._repeat2)
            self._repeat2.mul_(self._expand3)

        # Reduce over the output units (last axis).
        torch.sum(self._repeat2, 2, True, out=self.gradInput)
        self.gradInput.resize_as_(input)
    else:
        raise RuntimeError("1D or 2D input expected")

    return self.gradInput
评论列表
文章目录