models.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:keras_detect_tool_wear 作者: kidozh 项目源码 文件源码
def build_stateful_lstm_model_with_normalization(batch_size,
                                                 time_step,
                                                 input_dim,
                                                 output_dim,
                                                 dropout=0.2,
                                                 rnn_layer_num=2,
                                                 hidden_dim=128,
                                                 hidden_num=0,
                                                 rnn_type='LSTM'):
    """Build and compile a stacked recurrent regression model with BatchNormalization.

    Layer stack, in order:
      1. first recurrent layer (returns full sequences) with a fixed
         ``batch_input_shape=(batch_size, time_step, input_dim)``, then
         BatchNormalization;
      2. ``rnn_layer_num - 2`` intermediate recurrent layers (returning
         sequences), each followed by Dropout;
      3. BatchNormalization, then a final recurrent layer that returns only
         its last output;
      4. ``hidden_num`` Dense(hidden_dim) layers (no activation argument is
         passed to them);
      5. Dropout, then Dense(output_dim) as the output layer.

    The model is compiled with MSE loss, the ``'acc'`` metric and an RMSprop
    optimizer with ``lr=0.01``.

    Args:
        batch_size: batch size fixed into the first layer's ``batch_input_shape``.
        time_step: number of time steps per input sample.
        input_dim: number of features per time step.
        output_dim: number of units in the final Dense layer.
        dropout: dropout rate used after each intermediate recurrent layer and
            before the output layer.
        rnn_layer_num: number of recurrent layers. Values below 2 still produce
            two layers, because the first and final layers are always added.
        hidden_dim: units of every recurrent layer and every hidden Dense layer.
        hidden_num: number of hidden Dense layers after the recurrent stack.
        rnn_type: one of ``'LSTM'``, ``'GRU'`` or ``'SimpleRNN'``.

    Returns:
        The compiled ``Sequential`` model.

    Raises:
        ValueError: if ``rnn_type`` is not one of the supported options.
    """
    # Resolve and validate the recurrent layer class before building anything,
    # so an invalid configuration fails fast.
    if rnn_type == 'LSTM':
        rnn_cell = LSTM
    elif rnn_type == 'GRU':
        rnn_cell = GRU
    elif rnn_type == 'SimpleRNN':
        rnn_cell = SimpleRNN
    else:
        raise ValueError('Option rnn_type could only be configured as LSTM, GRU or SimpleRNN')

    model = Sequential()
    # NOTE(review): despite the function name, no layer is created with
    # stateful=True; only the batch size is fixed via batch_input_shape.
    model.add(rnn_cell(hidden_dim, return_sequences=True,
                       batch_input_shape=(batch_size, time_step, input_dim)))
    # BatchNormalization may help speed up training.
    model.add(BatchNormalization())

    # Intermediate recurrent layers; the first and final ones are added
    # separately, hence the "- 2".
    for _ in range(rnn_layer_num - 2):
        model.add(rnn_cell(hidden_dim, return_sequences=True))
        # Dropout between stacked recurrent layers to limit overfitting.
        model.add(Dropout(dropout))

    model.add(BatchNormalization())
    # Final recurrent layer emits only the last output, not the full sequence.
    model.add(rnn_cell(hidden_dim, return_sequences=False))

    # Optional hidden Dense layers (no activation argument is passed).
    for _ in range(hidden_num):
        model.add(Dense(hidden_dim))

    model.add(Dropout(dropout))
    model.add(Dense(output_dim))

    # NOTE(review): Keras 2.3+ renames `lr` to `learning_rate` (older Keras
    # only accepts `lr`); update if the project moves to a newer Keras.
    rmsprop = RMSprop(lr=0.01)

    # NOTE(review): the 'acc' metric is unlikely to be meaningful for an MSE
    # regression objective; kept for compatibility with existing callers.
    model.compile(loss='mse', metrics=['acc'], optimizer=rmsprop)

    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号