assemble.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:keras-gp 作者: alshedivat 项目源码 文件源码
def assemble_rnn(params, final_reshape=True):
    """Construct an RNN/LSTM/GRU model of the form: X-[H1-H2-...-HN]-Y.

    The H-layers are built from the entries of ``params['hidden_layers']``,
    in order; each may be followed by optional dropout and batch
    normalization layers.

    Args:
        params: Dictionary read with the following keys:
            ``'input_shape'``: shape passed to the input layer.
            ``'hidden_layers'``: iterable of dicts, each with ``'name'`` and
            ``'config'`` (forwarded to ``layers.deserialize`` as
            ``class_name`` / ``config``) and, optionally, ``'dropout'``
            (argument for ``layers.Dropout``) and ``'batch_norm'`` (keyword
            arguments for ``layers.BatchNormalization``). A missing key or a
            ``None`` value skips the corresponding extra layer.
            ``'output_shape'``: shape of the model output.
        final_reshape: If true, the flat output of the final dense layer is
            reshaped to ``params['output_shape']``.

    Returns:
        A ``KerasModel`` mapping the input layer to the output layer.
    """
    # Input layer
    inputs = layers.Input(shape=params['input_shape'])

    # Masking layer (mask value 0.0)
    previous = layers.Masking(mask_value=0.0)(inputs)

    # Hidden layers
    for layer in params['hidden_layers']:
        hidden = layers.deserialize(
            {'class_name': layer['name'], 'config': layer['config']})
        previous = hidden(previous)
        if layer.get('dropout') is not None:
            previous = layers.Dropout(layer['dropout'])(previous)
        if layer.get('batch_norm') is not None:
            previous = layers.BatchNormalization(
                **layer['batch_norm'])(previous)

    # Output layer
    output_shape = params['output_shape']
    # np.prod returns a NumPy scalar (a float for an empty shape); the dense
    # layer's unit count should be a plain Python int.
    output_dim = int(np.prod(output_shape))
    outputs = layers.Dense(output_dim)(previous)

    if final_reshape:
        outputs = layers.Reshape(output_shape)(outputs)

    return KerasModel(inputs=inputs, outputs=outputs)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号