models.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:keras_detect_tool_wear 作者: kidozh 项目源码 文件源码
def build_stateful_lstm_model(batch_size, time_step, input_dim, output_dim,
                              dropout=0.2, rnn_layer_num=2, hidden_dim=128,
                              hidden_num=0, rnn_type='LSTM'):
    """Build and compile a stacked recurrent model with a dense output head.

    The network is a ``Sequential`` stack of:

    1. a first recurrent layer with a fixed ``batch_input_shape`` of
       ``(batch_size, time_step, input_dim)`` that returns full sequences,
    2. ``rnn_layer_num - 2`` intermediate recurrent layers, each returning
       full sequences and followed by ``Dropout(dropout)``,
    3. a last recurrent layer that returns only its final output,
    4. ``hidden_num`` ``Dense(hidden_dim)`` layers (no activation argument
       is passed),
    5. ``Dropout(dropout)`` and a final ``Dense(output_dim)``.

    The first and last recurrent layers are always added, so values of
    ``rnn_layer_num`` below 2 still produce two recurrent layers.

    Args:
        batch_size: Batch dimension of the fixed input shape.
        time_step: Number of time steps in each input sequence.
        input_dim: Number of features per time step.
        output_dim: Number of units in the output ``Dense`` layer.
        dropout: Rate passed to every ``Dropout`` layer.
        rnn_layer_num: Total number of recurrent layers requested.
        hidden_dim: Units in every recurrent layer and hidden ``Dense`` layer.
        hidden_num: Number of hidden ``Dense`` layers after the recurrent
            stack.
        rnn_type: One of ``'LSTM'``, ``'GRU'`` or ``'SimpleRNN'``.

    Returns:
        The model, compiled with ``'mse'`` loss, the ``'acc'`` metric and an
        ``RMSprop`` optimizer using ``lr=0.01``.

    Raises:
        ValueError: If ``rnn_type`` is not one of the supported names.
    """
    # Resolve the layer class first so an invalid option fails before any
    # model object is created.
    if rnn_type == 'LSTM':
        rnn_cell = LSTM
    elif rnn_type == 'GRU':
        rnn_cell = GRU
    elif rnn_type == 'SimpleRNN':
        rnn_cell = SimpleRNN
    else:
        raise ValueError('Option rnn_type could only be configured as LSTM, GRU or SimpleRNN')

    model = Sequential()

    # NOTE(review): despite the function name, no layer is created with
    # stateful=True; only the fixed batch_input_shape is set. Confirm whether
    # stateful behaviour is actually intended.
    model.add(rnn_cell(hidden_dim,
                       return_sequences=True,
                       batch_input_shape=(batch_size, time_step, input_dim)))

    # Intermediate layers must emit the whole sequence so that the next
    # recurrent layer still receives 3-D input.
    for _ in range(rnn_layer_num - 2):
        model.add(rnn_cell(hidden_dim, return_sequences=True))
        # Dropout between stacked recurrent layers to limit over-fitting.
        model.add(Dropout(dropout))

    # The last recurrent layer collapses the sequence to its final output.
    model.add(rnn_cell(hidden_dim, return_sequences=False))

    # Optional fully connected hidden layers.
    for _ in range(hidden_num):
        model.add(Dense(hidden_dim))

    model.add(Dropout(dropout))
    model.add(Dense(output_dim))

    rmsprop = RMSprop(lr=0.01)

    # NOTE(review): 'acc' is reported alongside an MSE loss; confirm that an
    # accuracy metric is meaningful for the targets this model is trained on.
    model.compile(loss='mse', metrics=['acc'], optimizer=rmsprop)

    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号