def updateGradInput(self, input, gradOutput):
    """Route ``gradOutput`` back to the positions chosen by ``mask``.

    A flat index table ``0 .. mask.nelement() - 1`` is laid out in the
    shape of ``mask``. ``masked_select`` then picks the indices where
    ``mask`` is set. ``gradOutput`` is scattered along dim 0 of a
    zero-filled flat buffer at those indices. Finally the buffer is
    reshaped to the size of the source tensor.

    Args:
        input: a two-element pair ``(source, mask)``.
        gradOutput: values to scatter; one per selected index.

    Returns:
        ``self.gradInput``, set to ``[self._gradBuffer, self._gradMask]``.
        The second entry is filled with zeros in the shape of ``mask``.
    """
    source, mask = input
    count = mask.nelement()

    # Build the flat index table in mask's shape inside self._maskIndexBuffer.
    if source.type() != 'torch.cuda.FloatTensor':
        torch.arange(0, count, out=self._maskIndexBuffer).resize_(mask.size())
    else:
        # For CUDA float input the indices are generated in a separate staging
        # buffer (CPU-side, judging by its name) and then copied across.
        staging = self._maskIndexBufferCPU
        torch.arange(0, count, out=staging).resize_(mask.size())
        self._maskIndexBuffer.resize_(staging.size()).copy_(staging)

    # Flat positions where mask is set.
    torch.masked_select(self._maskIndexBuffer, mask, out=self._maskIndices)

    # Zeroed flat buffer, filled at the selected positions, then reshaped.
    grad_buffer = self._gradBuffer
    grad_buffer.resize_(source.nelement()).zero_()
    grad_buffer.scatter_(0, self._maskIndices, gradOutput)
    grad_buffer.resize_(source.size())

    # The mask input always receives an all-zero gradient.
    zero_mask_grad = self._gradMask.resize_(mask.size()).fill_(0)
    self.gradInput = [grad_buffer, zero_mask_grad]
    return self.gradInput
# Comment list
# Table of contents