def build_model(fragment_length, nb_filters, nb_output_bins, dilation_depth, nb_stacks, use_skip_connections,
                learn_all_outputs, _log, desired_sample_rate, use_bias, res_l2, final_l2):
    """Assemble the WaveNet network: an initial causal convolution, ``nb_stacks``
    stacks of gated residual blocks with exponentially increasing dilation
    (2**0 .. 2**dilation_depth), optional summed skip connections, and a
    1x1-conv + softmax output head over ``nb_output_bins`` amplitude bins.

    Returns the (uncompiled) Keras ``Model``.  Also logs the receptive field
    via ``_log``.  Raises ``DeprecationWarning`` when ``learn_all_outputs`` is
    False — that mode is no longer supported.
    """

    def residual_block(block_in):
        # Gated activation unit: tanh branch * sigmoid branch, each fed by a
        # causal dilated convolution of width 2 with dilation 2**i.
        # Note: CausalAtrousConvolution1D with the 'causal' flag comes from
        # github.com/basveeling/keras#@wavenet.
        # TODO: initialization, regularization?
        tanh_branch = CausalAtrousConvolution1D(nb_filters, 2, atrous_rate=2 ** i, border_mode='valid', causal=True,
                                                bias=use_bias,
                                                name='dilated_conv_%d_tanh_s%d' % (2 ** i, s), activation='tanh',
                                                W_regularizer=l2(res_l2))(block_in)
        sigm_branch = CausalAtrousConvolution1D(nb_filters, 2, atrous_rate=2 ** i, border_mode='valid', causal=True,
                                                bias=use_bias,
                                                name='dilated_conv_%d_sigm_s%d' % (2 ** i, s), activation='sigmoid',
                                                W_regularizer=l2(res_l2))(block_in)
        gated = layers.Merge(mode='mul', name='gated_activation_%d_s%d' % (i, s))([tanh_branch, sigm_branch])
        # Two 1x1 convolutions split the gated output into the residual
        # contribution (added back onto the block input) and the skip
        # contribution (collected for the output head).
        residual = layers.Convolution1D(nb_filters, 1, border_mode='same', bias=use_bias,
                                        W_regularizer=l2(res_l2))(gated)
        skip = layers.Convolution1D(nb_filters, 1, border_mode='same', bias=use_bias,
                                    W_regularizer=l2(res_l2))(gated)
        residual = layers.Merge(mode='sum')([block_in, residual])
        return residual, skip

    net_input = Input(shape=(fragment_length, nb_output_bins), name='input_part')
    skip_connections = []
    out = CausalAtrousConvolution1D(nb_filters, 2, atrous_rate=1, border_mode='valid', causal=True,
                                    name='initial_causal_conv')(net_input)
    # residual_block reads the loop variables `s` (stack index) and `i`
    # (dilation exponent) from this enclosing scope.
    for s in range(nb_stacks):
        for i in range(0, dilation_depth + 1):
            out, skip_out = residual_block(out)
            skip_connections.append(skip_out)

    if use_skip_connections:
        out = layers.Merge(mode='sum')(skip_connections)
    out = layers.Activation('relu')(out)
    out = layers.Convolution1D(nb_output_bins, 1, border_mode='same',
                               W_regularizer=l2(final_l2))(out)
    out = layers.Activation('relu')(out)
    out = layers.Convolution1D(nb_output_bins, 1, border_mode='same')(out)

    if not learn_all_outputs:
        raise DeprecationWarning('Learning on just all outputs is wasteful, now learning only inside receptive field.')
        # Unreachable (kept from the original): based on gif in deepmind blog,
        # take only the last timestep's output.
        out = layers.Lambda(lambda x: x[:, -1, :], output_shape=(out._keras_shape[-1],))(out)

    out = layers.Activation('softmax', name="output_softmax")(out)
    model = Model(net_input, out)

    receptive_field, receptive_field_ms = compute_receptive_field()
    _log.info('Receptive Field: %d (%dms)' % (receptive_field, int(receptive_field_ms)))
    return model