model.py source file

python

Project: keras_detect_tool_wear  Author: kidozh
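from keras.layers import (Activation, BatchNormalization, Conv1D, Dense,
                          Dropout, Flatten, Input, Masking, TimeDistributed)
from keras.models import Model
from keras.optimizers import Adam

# first_block and repeated_block are used below but are not shown in this excerpt.


def build_model(timestep, input_dim, output_dim, dropout=0.5,
                recurrent_layers_num=4, cnn_layers_num=6, lr=0.001):
    # Build and compile a 1D-convolutional regression model.
    # recurrent_layers_num is accepted but not used in this function.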
    inp = Input(shape=(timestep,input_dim))
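    # Treat zero-valued input as padding (mask_value=0).
    output = TimeDistributed(Masking(mask_value=0))(inp)
    #output = inp
    # Pointwise convolution (kernel size 1) to 128 channels, then batch norm and ReLU.
    output = Conv1D(128, 1)(output)
    output = BatchNormalization()(output)
    output = Activation('relu')(output)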

    output = first_block(output, (64, 128), dropout=dropout)

    output = Dropout(dropout)(output)
    # Stack cnn_layers_num repeated blocks.
    for _ in range(cnn_layers_num):
        output = repeated_block(output, (64, 128), dropout=dropout)

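    # Flatten the sequence into a single feature vector per sample.
    output = Flatten()(output)
    #output = LSTM(128, return_sequences=False)(output)

    output = BatchNormalization()(output)
    output = Activation('relu')(output)
    # Linear output layer for regression (trained with MSE loss below).
    output = Dense(output_dim)(output)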


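    model = Model(inp, output)

    # `learning_rate` replaces the deprecated `lr` argument (Keras 2.3+).
    optimizer = Adam(learning_rate=lr)
    model.compile(optimizer=optimizer, loss='mse', metrics=['mae'])
    return model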