def updateOutput(self, input):
    """Pad ``input`` with ``self.value`` along ``self.dim``; return ``self.output``.

    ``abs(self.pad)`` filler slots are inserted along dimension ``self.dim``.
    Where they go is controlled by ``self.index``:

    * ``self.pad > 0``: ``self.index`` is measured from the end of the
      dimension, so ``index == 0`` puts all padding after the last element.
    * ``self.pad <= 0``: ``self.index`` is measured from the start, so
      ``index == 0`` puts all padding before the first element.

    Side effects: ``self.outputSize`` is set to the padded shape, and
    ``self.output`` is resized and overwritten in place.

    Args:
        input: tensor to pad.

    Returns:
        ``self.output``, holding the padded result.
    """
    axis = self.dim
    amount = abs(self.pad)  # both sign branches reduce to the magnitude

    padded_shape = list(input.size())
    padded_shape[axis] += amount
    self.outputSize = torch.Size(padded_shape)

    # Begin with a buffer that is entirely the pad value; the input is then
    # copied over every position except the gap.
    self.output.resize_(self.outputSize)
    self.output.fill_(self.value)

    extent = input.size(axis)
    # Turn self.index into an absolute split point counted from the start.
    split = extent - self.index if self.pad > 0 else self.index

    if split == 0:
        # The whole gap precedes the input.
        self.output.narrow(axis, amount, extent).copy_(input)
    elif split == extent:
        # The whole gap follows the input.
        self.output.narrow(axis, 0, extent).copy_(input)
    else:
        # The gap sits inside: copy the head, skip `amount` slots, copy the tail.
        tail = extent - split
        head_src = input.narrow(axis, 0, split)
        tail_src = input.narrow(axis, split, tail)
        self.output.narrow(axis, 0, split).copy_(head_src)
        self.output.narrow(axis, split + amount, tail).copy_(tail_src)

    return self.output
评论列表
文章目录