def forward(self, input):
    """Pad ``input`` with the constant ``self.value`` using ``self.padding``.

    A 3-D input is given a temporary leading dimension of size 1 before
    ``F.pad`` is applied, and that dimension is removed again afterwards.
    The result therefore has the same rank as ``input``. Inputs of any
    other rank are passed to ``F.pad`` unchanged.
    """
    unbatched = input.dim() == 3
    source = input.unsqueeze(0) if unbatched else input
    padded = F.pad(source, self.padding, 'constant', self.value)
    if not unbatched:
        return padded
    # Strip the leading singleton dimension that was added before padding.
    return padded.view(padded.shape[1:])