def updateGradInput(self, input, gradOutput):
    """Compute the gradient of the module output with respect to ``input``.

    The incoming gradient is divided element-wise by the (slightly offset)
    forward output, broadcast across the input dimension, multiplied by
    ``self._repeat`` and summed over the output dimension.

    Args:
        input: 2-D tensor of shape ``(batchSize, inputSize)``; any other
            rank trips the assertion below.
        gradOutput: gradient with respect to the module output.

    Returns:
        ``self.gradInput`` resized to the shape of ``input``, or ``None``
        when ``self.gradInput`` is ``None``.

    Raises:
        AssertionError: if ``input`` is not 2-D.
    """
    if self.gradInput is None:
        return None

    # Lazily allocate the scratch buffers. Compare against None explicitly:
    # truth-testing a tensor is ambiguous (or an error) as soon as it holds
    # data, which would break every call after the first one.
    if self._div is None:
        self._div = input.new()
    if self._output is None:
        self._output = self.output.new()
    if self._gradOutput is None:
        self._gradOutput = input.new()
    if self._expand3 is None:
        self._expand3 = input.new()

    if not self.fastBackward:
        # NOTE(review): presumably re-runs the forward pass so the cached
        # buffers (self.output, self._repeat) match `input` -- confirm.
        self.updateOutput(input)

    inputSize, outputSize = self.weight.size(0), self.weight.size(1)

    # dy_j   -2 * (w_j - x)     x - w_j
    # ---- = ---------------- = -------
    #  dx    2 || w_j - x ||      y_j

    # to prevent div by zero (NaN) bugs
    self._output.resize_as_(self.output).copy_(self.output).add_(0.0000001)
    self._view(self._gradOutput, gradOutput, gradOutput.size())
    # Destination tensors go through `out=`; the Lua-style "result first"
    # positional order is not valid in the Python API.
    torch.div(gradOutput, self._output, out=self._div)
    assert input.dim() == 2
    batchSize = input.size(0)

    # Insert a singleton input axis so the quotient can be broadcast over
    # every input feature.
    self._div.resize_(batchSize, 1, outputSize)
    self._expand3 = self._div.expand(batchSize, inputSize, outputSize)

    # NOTE(review): self._repeat is assumed to hold (x - w_j) as left by
    # updateOutput -- confirm against the forward pass.
    if torch.typename(input) == 'torch.cuda.FloatTensor':
        self._repeat2.resize_as_(self._expand3).copy_(self._expand3)
        self._repeat2.mul_(self._repeat)
    else:
        torch.mul(self._repeat, self._expand3, out=self._repeat2)

    # Reduce over the output axis. resize_as_ below gives the result the
    # shape of `input` whether or not the reduced axis is kept.
    torch.sum(self._repeat2, 2, out=self.gradInput)
    self.gradInput.resize_as_(input)
    return self.gradInput
# (Web-page residue, not code: "Comment list")
# (Web-page residue, not code: "Table of contents")