def accGradParameters(self, input, gradOutput, scale=1):
    """Accumulate this pass's weight gradient into ``self.gradWeight``.

    Computes ``gradWeight += scale * sum(input * gradOutput over dim 0)``.
    Dim 0 is the batch dimension (``batchSize = input.size(0)``), and all
    remaining dimensions are flattened into one.

    Args:
        input: tensor given to the forward pass; dim 0 is treated as the batch.
        gradOutput: gradient with respect to the output. Presumably it has the
            same number of elements as ``input`` (TODO confirm), since the two
            are multiplied element-wise after flattening.
        scale: multiplier applied to the batch-summed gradient before it is
            accumulated.

    Returns:
        None. ``self.gradWeight`` is updated in place through a ``(1, -1)``
        view of its storage.
    """
    # Lazily allocate each scratch buffer with input's tensor type.
    # The original only created _input/_gradWeight/_sum here, as a group.
    # It relied on _gradOutput and _repeat already existing, so it crashed
    # when they had not been allocated elsewhere first. Allocate each one
    # independently and only if it is still missing; existing buffers are
    # reused as before.
    for buffer_name in ("_input", "_gradWeight", "_sum", "_gradOutput", "_repeat"):
        if getattr(self, buffer_name, None) is None:
            setattr(self, buffer_name, input.new())

    batchSize = input.size(0)
    # Reshape input and gradOutput to (batchSize, -1) so each row is one
    # sample. Presumably contiguousView makes the first argument a
    # contiguous view/copy of the second; confirm against its definition.
    contiguousView(self._input, input, batchSize, -1)
    contiguousView(self._gradOutput, gradOutput, batchSize, -1)
    # A (1, -1) view shares storage with gradWeight, so the in-place add
    # below updates gradWeight itself.
    self._gradWeight = self.gradWeight.view(1, -1)
    # Legacy result-first call: presumably writes input * gradOutput
    # element-wise into _repeat.
    torch.mul(self._repeat, self._input, self._gradOutput)
    # Sum the per-sample products over dim 0 (the batch) into _sum.
    torch.sum(self._sum, self._repeat, 0)
    # Legacy add_(value, other) overload: gradWeight += scale * _sum.
    self._gradWeight.add_(scale, self._sum)
评论列表
文章目录