residual_model.py source code

Project: keras_detect_tool_wear    Author: kidozh
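```python
from keras.layers import (Activation, BatchNormalization, Conv2D, Dense,
                          Flatten, Input, Masking, TimeDistributed)
from keras.models import Model

# first_2d_block and repeated_2d_block are residual-block helpers defined
# elsewhere in the project (not shown here).


def build_2d_main_residual_network(batch_size,
                                   width,
                                   height,
                                   channel_size,
                                   output_dim,
                                   loop_depth=15,
                                   dropout=0.3):
    """Build and compile a 2D residual network that regresses output_dim values.

    Note: batch_size is accepted but not used when building the model.
    """
    inp = Input(shape=(width, height, channel_size))

    # Mask out invalid (all-zero) data.
    # Conv2D layers do not consume Keras masks, so this likely has little
    # effect on the convolutional stack below.
    out = TimeDistributed(Masking(mask_value=0))(inp)

    # Stem: 128 filters, 5x5 kernel; the default 'valid' padding shrinks each
    # spatial dimension by 4.
    out = Conv2D(128, 5, data_format='channels_last')(out)
    out = BatchNormalization()(out)
    out = Activation('relu')(out)

    # One initial residual block followed by loop_depth repeated blocks.
    out = first_2d_block(out, (64, 128), dropout=dropout)
    for _ in range(loop_depth):
        out = repeated_2d_block(out, (64, 128), dropout=dropout)

    # Flatten the feature maps, then normalize, activate, and project.
    out = Flatten()(out)
    out = BatchNormalization()(out)
    out = Activation('relu')(out)
    out = Dense(output_dim)(out)

    model = Model(inp, out)

    # Regression setup: MSE loss, tracking MSE and MAE as metrics.
    model.compile(loss='mse', optimizer='adam', metrics=['mse', 'mae'])
    return model
```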
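The masking step wraps Keras `Masking(mask_value=0)` in `TimeDistributed`. As I understand the Keras semantics, `TimeDistributed` applies the wrapped layer to each slice along axis 1 (the `width` axis here), and `Masking` then masks a position only when all of its `channel_size` values equal `mask_value`; its forward pass multiplies the input by that boolean mask. The pure-Python sketch below reproduces that rule for one sample stored as nested lists of shape `(width, height, channel_size)`. The helpers `position_mask` and `apply_mask` are hypothetical illustrations, not Keras APIs.

```python
from typing import List

# One sample laid out as (width, height, channel_size), like the model's Input.
Sample = List[List[List[float]]]


def position_mask(sample: Sample, mask_value: float = 0.0) -> List[List[bool]]:
    """Return True where a (width, height) position is kept, False where masked.

    Hypothetical helper mirroring the Masking rule as applied per width-slice
    by TimeDistributed: a position is masked only if every channel equals
    mask_value.
    """
    return [
        [any(value != mask_value for value in channels) for channels in row]
        for row in sample
    ]


def apply_mask(sample: Sample, mask_value: float = 0.0) -> Sample:
    """Zero the channels of masked positions, as Masking's forward pass does."""
    keep = position_mask(sample, mask_value)
    return [
        [
            list(channels) if keep[w][h] else [0.0] * len(channels)
            for h, channels in enumerate(row)
        ]
        for w, row in enumerate(sample)
    ]


# width=2, height=3, channel_size=2
sample = [
    [[0.0, 0.0], [1.0, 0.0], [0.5, 2.0]],
    [[0.0, 3.0], [0.0, 0.0], [0.0, 0.0]],
]
print(position_mask(sample))  # [[False, True, True], [True, False, False]]
print(apply_mask(sample) == sample)  # True: with mask_value=0, values are unchanged
```

Because masked positions are already zero when `mask_value=0`, the wrapped `Masking` layer leaves the tensor values unchanged; any influence it has would come only through the Keras mask it propagates.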
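The stem `Conv2D(128, 5, data_format='channels_last')` relies on the Keras defaults `padding='valid'` and `strides=1`, so each spatial dimension shrinks by `kernel_size - 1 = 4` before the residual blocks run. The small pure-Python sketch below computes that output shape and the layer's parameter count. The helper names are hypothetical, and the example sizes (a 64 x 32 input with 3 channels) are arbitrary illustration values, not taken from the project.

```python
from typing import Tuple


def conv2d_valid_output_shape(width: int, height: int, kernel_size: int = 5,
                              filters: int = 128, stride: int = 1) -> Tuple[int, int, int]:
    """Output (width, height, filters) of a square-kernel Conv2D with padding='valid'."""
    out_w = (width - kernel_size) // stride + 1
    out_h = (height - kernel_size) // stride + 1
    if out_w <= 0 or out_h <= 0:
        raise ValueError("input is smaller than the kernel")
    return out_w, out_h, filters


def conv2d_param_count(channel_size: int, kernel_size: int = 5,
                       filters: int = 128, use_bias: bool = True) -> int:
    """One kernel_size x kernel_size x channel_size kernel per filter, plus one bias per filter."""
    weights = kernel_size * kernel_size * channel_size * filters
    return weights + (filters if use_bias else 0)


# Illustrative sizes only (not taken from the project): 64 x 32 input with 3 channels.
print(conv2d_valid_output_shape(64, 32))  # (60, 28, 128)
print(conv2d_param_count(channel_size=3))  # 9728
```

The Flatten size that feeds `Dense(output_dim)` also depends on what `first_2d_block` and `repeated_2d_block` do to the spatial dimensions, which this sketch does not model.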