wavenet.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:wavenet 作者: musyoku 项目源码 文件源码
def __call__(self, x):
        """Apply a dilated 1-D convolution to ``x``.

        ``x`` is read as a 4-D array: axis 0 is the batch and axis 3 is the
        time/width axis. The layout is presumably (batch, channels, 1, width),
        matching the output reshape below; TODO confirm against callers.

        When ``self.dilation == 1``, the input is padded by
        ``filter_width - 1`` with ``padding_1d``. It is then passed straight
        to the parent class's convolution.

        Otherwise:

        1. The width is padded so that it is a multiple of ``self.dilation``
           and so that ``width // dilation >= filter_width``.
        2. The input is reshaped to ``(batch, in_channels, -1, dilation)``
           and convolved by the parent class.
        3. The result is reshaped to ``(batch, out_channels, 1, -1)``.
        4. Finally it is trimmed with ``slice_1d`` or padded with
           ``padding_1d`` so that its width equals the input width.

        Args:
            x: Input variable; must expose ``x.data.shape``.

        Returns:
            The convolution output. In the dilated path its width (axis 3)
            equals the input width.
        """
        batchsize = x.data.shape[0]
        input_x_width = x.data.shape[3]

        if self.dilation == 1:
            # perform normal convolution
            padded_x = self.padding_1d(x, self.filter_width - 1)
            return super(DilatedConvolution1D, self).__call__(padded_x)

        # padding
        pad = 0
        padded_x_width = input_x_width

        # The reshape below needs the width to be a multiple of the dilation.
        mod = padded_x_width % self.dilation
        if mod > 0:
            pad += self.dilation - mod
            padded_x_width = input_x_width + pad

        # After the reshape there must be at least filter_width rows for the
        # kernel to fit. padded_x_width is an exact multiple of dilation here,
        # so floor division is exact. True division (`/`) would yield a float
        # on Python 3, which makes `pad` a float and breaks padding/reshape.
        height = padded_x_width // self.dilation
        if height < self.filter_width:
            # Adding whole multiples of dilation keeps the width divisible.
            pad += (self.filter_width - height) * self.dilation

        if pad > 0:
            padded_x = self.padding_1d(x, pad)
        else:
            padded_x = x

        # Row-major reshape: consecutive entries along axis 2 are `dilation`
        # samples apart in the original sequence, so (dilation - 1) samples
        # are skipped between kernel taps. A transpose to (0, 1, 3, 2) is
        # deliberately omitted; the original author notes it is unnecessary
        # when residual_conv_filter_width is set to the kernel's height.
        padded_x = F.reshape(padded_x, (batchsize, self.in_channels, -1, self.dilation))

        # convolution
        out = super(DilatedConvolution1D, self).__call__(padded_x)

        # reshape to the original (flattened) time axis
        out = F.reshape(out, (batchsize, self.out_channels, 1, -1))

        # remove padded elements / add missing elements so the output width
        # matches the input width
        cut = out.data.shape[3] - input_x_width
        if cut > 0:
            out = self.slice_1d(out, cut)
        elif cut < 0:
            out = self.padding_1d(out, -cut)

        return out
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号