def updateGradInput(self, input, gradOutput):
self.gradInput.resize_as_(input).zero_()
size = list(input.size())
size.insert(self.dim, 1)
gradInput = self.gradInput.view(*size)
torch.sum(gradInput, gradOutput, self.dim)
return self.gradInput
评论列表
文章目录