wavenet.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:eva 作者: israelg99 项目源码 文件源码
def Wavenet(input_shape, filters, depth, stacks, last=0, h=None, build=True):
    # TODO: Soft targets? A float to make targets a gaussian with stdev.
    # TODO: Train only receptive field. The temporal-first outputs are computed from zero-padding.
    # TODO: Global conditioning?
    # TODO: Local conditioning?

    _, nb_bins = input_shape

    input_audio = Input(input_shape, name='audio_input')

    model = CausalAtrousConvolution1D(filters, 2, mask_type='A', atrous_rate=1, border_mode='valid')(input_audio)

    out, skip_connections = WavenetBlocks(filters, depth, stacks)(model)

    out = Merge(mode='sum', name='merging_skips')(skip_connections)
    out = PReLU()(out)

    out = Convolution1D(nb_bins, 1, border_mode='same')(out)
    out = PReLU()(out)

    out = Convolution1D(nb_bins, 1, border_mode='same')(out)

    # https://storage.googleapis.com/deepmind-live-cms/documents/BlogPost-Fig2-Anim-160908-r01.gif
    if last > 0:
        out = Lambda(lambda x: x[:, -last:], output_shape=(last, out._keras_shape[2]), name='last_out')(out)

    out = Activation('softmax')(out)

    if build:
        model = Model(input_audio, out)
        model.compile(Nadam(), 'sparse_categorical_crossentropy')

    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号