def create_wavenet_block(inp, num_dilation_layer, input_dim, output_dim, name=None):
    """Build a stack of gated dilated 1-D convolutions and sum their skip outputs.

    Layer ``i`` (0-based) is created with ``lib.ops.dil_conv_1D`` using
    dilation ``2**i`` and the ``'gated'`` non-linearity; its main output
    feeds the next layer.  Each layer also returns a skip contribution,
    which is multiplied by its own entry of a ``num_dilation_layer``-long
    weight vector (created through ``lib.param`` and initialised to ones).
    The weighted skip contributions are cropped on axis 1 to line up with
    the last layer's contribution and summed.

    Args:
        inp: Input to the first layer.  NOTE(review): skip alignment slices
            axis 1, so this presumably has time on axis 1 -- confirm.
        num_dilation_layer: Number of dilated layers to stack; must be >= 1.
        input_dim: Input dimension (presumably channels) of the first layer.
        output_dim: Output dimension of every layer; also used as the input
            dimension of every layer after the first.
        name: Required prefix for the names of all parameters created here.

    Returns:
        ``(layer_out, skip_out)``: the last layer's main output and the sum
        of the weighted, cropped skip contributions.

    Raises:
        ValueError: if ``name`` is None or ``num_dilation_layer`` < 1.
    """
    # Explicit checks rather than ``assert``: asserts vanish under ``python -O``.
    if name is None:
        raise ValueError("create_wavenet_block requires a name")
    if num_dilation_layer < 1:
        # With no layers there is no skip contribution to return; fail before
        # any parameter is registered.
        raise ValueError(
            "num_dilation_layer must be >= 1, got {}".format(num_dilation_layer)
        )

    layer_out = inp
    skip_contrib = []
    skip_weights = lib.param(
        name + ".parametrized_weights",
        lib.floatX(numpy.ones((num_dilation_layer,))),
    )
    for i in range(num_dilation_layer):
        layer_out, skip_c = lib.ops.dil_conv_1D(
            layer_out,
            output_dim,
            input_dim if i == 0 else output_dim,
            2,  # presumably the filter width -- confirm against dil_conv_1D
            dilation=2**i,
            non_linearity='gated',
            name=name + ".dilation_{}".format(i + 1),
        )
        skip_contrib.append(skip_c * skip_weights[i])

    # Crop each earlier skip contribution by dropping leading steps on axis 1
    # so it lines up with the last one, then sum (last layer first, then
    # layers n-2 down to 0).  Layer k is cropped by 2**n - 2**(k+1), i.e.
    # 2**(n-1) + 2**(n-2) + ... + 2**(k+1).
    # NOTE(review): this assumes each layer shortens axis 1 by exactly its
    # dilation (width-2, unpadded convolution) -- confirm in lib.ops.
    n = num_dilation_layer
    skip_out = skip_contrib[-1]
    for k in range(n - 2, -1, -1):
        offset = 2**n - 2**(k + 1)
        skip_out = skip_out + skip_contrib[k][:, offset:]
    return layer_out, skip_out
评论列表
文章目录