wavent.py source code

python

Project: sampleRNN_ICLR2017    Author: soroushmehr
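import numpy

import lib      # project-local package of sampleRNN_ICLR2017 (assumed to be on the path)
import lib.ops


def create_wavenet_block(inp, num_dilation_layer, input_dim, output_dim, name=None):
    # Stack `num_dilation_layer` gated dilated 1D convolutions with dilations
    # 1, 2, 4, ... and return the last layer's output together with the
    # weighted sum of every layer's skip contribution.
    assert name is not None
    layer_out = inp
    skip_contrib = []
    # One learnable scalar weight per layer's skip contribution, initialised to 1.
    skip_weights = lib.param(name+".parametrized_weights", lib.floatX(numpy.ones((num_dilation_layer,))))
    for i in range(num_dilation_layer):
        layer_out, skip_c = lib.ops.dil_conv_1D(
            layer_out,
            output_dim,
            input_dim if i == 0 else output_dim,  # only the first layer sees input_dim
            2,  # presumably the filter size
            dilation=2**i,
            non_linearity='gated',
            name=name+".dilation_{}".format(i+1)
        )
        skip_c = skip_c*skip_weights[i]
        skip_contrib.append(skip_c)

    # Start from the deepest layer's skip contribution, which is the shortest.
    skip_out = skip_contrib[-1]

    # Walk back towards the first layer. Earlier contributions are longer along
    # axis 1, so their first j steps are dropped before they are added.
    j = 0
    for i in range(num_dilation_layer-1):
        j += 2**(num_dilation_layer-i-1)
        skip_out = skip_out + skip_contrib[num_dilation_layer-2 - i][:, j:]

    return layer_out, skip_out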
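
The crop offsets `j` in the final loop can be checked with plain arithmetic. The sketch below assumes that layer `i` shortens the time axis by `2**i`, which is what an unpadded convolution with filter size 2 and dilation `2**i` would do. That assumption is inferred from the cropping logic, not confirmed by `lib.ops.dil_conv_1D` itself. The helper names `layer_output_lengths` and `skip_crop_offsets` are hypothetical and are not part of the project.

```python
def layer_output_lengths(seq_len, num_dilation_layer):
    """Length of each layer's output, assuming layer i removes 2**i time steps."""
    lengths = []
    length = seq_len
    for i in range(num_dilation_layer):
        length -= 2 ** i
        lengths.append(length)
    return lengths


def skip_crop_offsets(num_dilation_layer):
    """Offsets j used for skip_contrib[n-2], skip_contrib[n-3], ..., skip_contrib[0]."""
    offsets = []
    j = 0
    for i in range(num_dilation_layer - 1):
        j += 2 ** (num_dilation_layer - i - 1)
        offsets.append(j)
    return offsets


n, seq_len = 4, 32
lengths = layer_output_lengths(seq_len, n)
offsets = skip_crop_offsets(n)

# Pair each offset with the layer it is applied to (second-to-last layer first)
# and compute the length that remains after dropping the first j steps.
cropped = [lengths[n - 2 - k] - j for k, j in enumerate(offsets)]

print("layer output lengths:", lengths)
print("crop offsets:        ", offsets)
print("lengths after crop:  ", cropped)
```

Under this assumption every cropped contribution ends up with the same length as the last layer's output, which is why the element-wise sum in `create_wavenet_block` lines up.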