def updateGradInput(self, input, gradOutput):
    """Compute gradients for a masked-select style forward pass.

    ``input`` is unpacked as an ``(input, mask)`` pair. The flat indices
    0 .. mask.nelement()-1 are filtered by ``mask``. Each entry of
    ``gradOutput`` is then scattered to the flat position it came from, and
    every other position gets zero gradient. The mask itself receives an
    all-zero gradient.

    Args:
        input: two-element sequence ``(input, mask)``. NOTE(review): presumably
            ``mask`` has as many elements as ``input`` (the selected flat
            indices are used to index a buffer of ``input.nelement()``
            slots) — confirm.
        gradOutput: gradient w.r.t. the forward output. Assumed to be a 1-D
            tensor with one entry per selected element — verify against the
            forward pass.

    Returns:
        ``self.gradInput``, a list ``[grad_input, grad_mask]``. ``grad_input``
        has ``input.size()``. ``grad_mask`` has ``mask.size()`` and is
        filled with 0.

    NOTE(review): this uses the legacy torch convention of passing the output
    tensor as the first positional argument (``torch.range(out, ...)``,
    ``torch.masked_select(out, ...)``). Current PyTorch expects ``out=`` as a
    keyword argument and deprecates ``torch.range``. Confirm the targeted
    torch version before modernising.
    """
    input, mask = input
    if input.type() == 'torch.cuda.FloatTensor':
        # Build the index sequence 0 .. nelement-1 (torch.range's end is
        # inclusive) in the CPU-side buffer and shape it like the mask. The
        # chained resize_ acts on the returned tensor, which under the legacy
        # convention is the out buffer itself. Then copy the result into
        # _maskIndexBuffer.
        # NOTE(review): presumably the range is built in a separate buffer
        # because it cannot be produced directly in _maskIndexBuffer's type
        # on the CUDA path — confirm.
        torch.range(self._maskIndexBufferCPU, 0, mask.nelement()-1).resize_(mask.size())
        self._maskIndexBuffer.resize_(self._maskIndexBufferCPU.size()).copy_(self._maskIndexBufferCPU)
    else:
        # Fill the index buffer with 0 .. nelement-1 directly, shaped like the mask.
        torch.range(self._maskIndexBuffer, 0, mask.nelement()-1).resize_(mask.size())
    # Keep only the flat indices where the mask is set. The result is written
    # into _maskIndices (legacy first-positional out argument; the return
    # value is discarded).
    torch.masked_select(self._maskIndices, self._maskIndexBuffer, mask)
    # Flat zero buffer with one slot per input element ...
    self._gradBuffer.resize_(input.nelement()).zero_()
    # ... route each gradOutput entry to the flat position it was selected from ...
    self._gradBuffer.scatter_(0, self._maskIndices, gradOutput)
    # ... and restore the input's shape.
    self._gradBuffer.resize_(input.size())
    # The mask gets an all-zero gradient of its own shape.
    self.gradInput = [self._gradBuffer, self._gradMask.resize_(mask.size()).fill_(0)]
    return self.gradInput
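# NOTE(review): stray web-page navigation labels ("Comment list", "Table of
# contents") left over from a scraped copy of this file; not part of the module.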