def __call__(self, x):
    batchsize = x.data.shape[0]
    input_x_width = x.data.shape[3]
    if self.dilation == 1:
        # perform a normal (non-dilated) convolution
        padded_x = self.padding_1d(x, self.filter_width - 1)
        return super(DilatedConvolution1D, self).__call__(padded_x)
    # padding
    pad = 0
    padded_x_width = input_x_width
    ## check if we can reshape (the width must be divisible by the dilation)
    mod = padded_x_width % self.dilation
    if mod > 0:
        pad += self.dilation - mod
        padded_x_width = input_x_width + pad
    ## check if height < filter width
    height = padded_x_width // self.dilation
    if height < self.filter_width:
        pad += (self.filter_width - height) * self.dilation
        padded_x_width = input_x_width + pad
    if pad > 0:
        padded_x = self.padding_1d(x, pad)
    else:
        padded_x = x
    # fold the time axis so the kernel skips (dilation - 1) elements between taps
    padded_x = F.reshape(padded_x, (batchsize, self.in_channels, -1, self.dilation))
    # the transpose can be omitted when residual_conv_filter_width is set to the kernel's height
    # padded_x = F.transpose(padded_x, (0, 1, 3, 2))
    # convolution
    out = super(DilatedConvolution1D, self).__call__(padded_x)
    # reshape back to the original 1D layout
    out = F.reshape(out, (batchsize, self.out_channels, 1, -1))
    # remove padded elements / add missing elements
    cut = out.data.shape[3] - input_x_width
    if cut > 0:
        out = self.slice_1d(out, cut)
    elif cut < 0:
        out = self.padding_1d(out, -cut)
    return out
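
# --- Not part of the original class: a minimal numpy sketch of the reshape trick above ---
# Assumption: this only illustrates why folding the time axis into rows of width
# `dilation` lets an ordinary convolution behave like a dilated one; the names
# below (dilation, x, folded) are illustrative and do not come from the class.
import numpy as np

dilation = 2
x = np.arange(8)                     # 8 samples along the time axis: 0..7
folded = x.reshape(-1, dilation)     # shape (4, 2): rows [0,1], [2,3], [4,5], [6,7]
# Each column now holds every `dilation`-th sample, so a filter sliding down
# one column sees inputs spaced `dilation` apart, exactly like a dilated kernel.
print(folded[:, 0])                  # -> [0 2 4 6]
print(folded[:, 1])                  # -> [1 3 5 7]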