wavenet.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:WaveNet-Theano 作者: huyouare 项目源码 文件源码
def DilatedConv1D(name, input_dim, output_dim, filter_size, inputs, dilation, mask_type=None, apply_biases=True):
    """
    Dilated 1D convolution along the length axis, implemented via conv2d
    with a (filter_size, 1) kernel.

    name: prefix for the created parameters (name+'.Filters', name+'.Biases')
    input_dim: number of input channels
    output_dim: number of output channels
    filter_size: number of filter taps along the length axis
    inputs.shape: (batch size, length, input_dim)
    dilation: filter dilation factor along the length axis
    mask_type: None, 'a', 'b'
        Any non-None value zeroes every filter tap whose index is greater
        than filter_size//2 and scales the initial filter values by sqrt(2).
        NOTE: 'a' and 'b' currently behave identically; the center tap is
        never masked.
    apply_biases: if True, add a per-output-channel bias initialised to zero
    output.shape: (batch size, length, output_dim)
    """
    def uniform(stdev, size):
        """uniform distribution with the given stdev and size"""
        # A uniform on [-a, a] has standard deviation a / sqrt(3).
        half_width = stdev * numpy.sqrt(3)
        return numpy.random.uniform(
            low=-half_width,
            high=half_width,
            size=size
        ).astype(theano.config.floatX)

    # conv2d filter layout: (output dim, input dim, height, width)
    filter_shape = (output_dim, input_dim, filter_size, 1)
    masked = mask_type is not None

    filters_init = uniform(
        1./numpy.sqrt(input_dim * filter_size),
        filter_shape
    )

    if masked:
        # Roughly half of the taps get zeroed by the mask below, so the
        # surviving ones are scaled up at initialisation.
        filters_init *= lib.floatX(numpy.sqrt(2.))

    filters = lib.param(
        name+'.Filters',
        filters_init
    )

    if masked:
        mask = numpy.ones(filter_shape, dtype=theano.config.floatX)
        # Zero all taps after the center one. A single slice assignment
        # replaces the per-index loop (and its Python-2-only xrange).
        center = filter_size//2
        mask[:, :, center + 1:, :] = 0.
        filters = filters * mask

    # (batch, length, input_dim) -> (batch, length, 1, input_dim)
    inputs = inputs.reshape((inputs.shape[0], inputs.shape[1], 1, inputs.shape[2]))
    # conv2d takes inputs as (batch size, input channels, height, width);
    # here height is the length axis and width is the dummy axis of size 1.
    inputs = inputs.dimshuffle(0, 3, 1, 2)
    result = T.nnet.conv2d(inputs, filters, border_mode='half', filter_flip=False, filter_dilation=(dilation, 1))

    if apply_biases:
        biases = lib.param(
            name+'.Biases',
            numpy.zeros(output_dim, dtype=theano.config.floatX)
        )
        # Broadcast the bias over batch, height and width.
        result = result + biases[None, :, None, None]

    # (batch, output_dim, length, 1) -> (batch, length, 1, output_dim),
    # then drop the dummy width axis.
    result = result.dimshuffle(0, 2, 3, 1)
    return result.reshape((result.shape[0], result.shape[1], result.shape[3]))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号