model.py source code

python

Project: keras_detect_tool_wear    Author: kidozh
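# Keras layers and model class used below (the original listing omits its imports)
from keras.layers import (Activation, BatchNormalization, Conv1D, Dense,
                          Flatten, Input, concatenate)
from keras.models import Model


def build_multi_input_main_residual_network(batch_size,
                                            a2_time_step,
                                            d2_time_step,
                                            d1_time_step,
                                            input_dim,
                                            output_dim,
                                            loop_depth=15,
                                            dropout=0.3):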
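    '''
    Multi-input residual network for wavelet-decomposed signals.

    :param batch_size: batch size (not used when building the graph)
    :param a2_time_step: number of time steps in the level-2 approximation (a2) input
    :param d2_time_step: number of time steps in the level-2 detail (d2) input
    :param d1_time_step: number of time steps in the level-1 detail (d1) input
    :param input_dim: number of channels per time step, shared by all three inputs
    :param output_dim: number of regression outputs
    :param loop_depth: number of repeated residual blocks after the first block
    :param dropout: dropout rate passed to the residual blocks
    :return: compiled Keras Model
    '''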
    a2_inp = Input(shape=(a2_time_step, input_dim), name='a2')
    d2_inp = Input(shape=(d2_time_step, input_dim), name='d2')
    d1_inp = Input(shape=(d1_time_step, input_dim), name='d1')

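    # Stack the three wavelet components along the time axis; this requires
    # all three inputs to share the same input_dim.
    out = concatenate([a2_inp, d2_inp, d1_inp], axis=1)

    # Stem: convolution (128 filters, kernel size 5, default 'valid' padding) -> BN -> ReLU
    out = Conv1D(128, 5)(out)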
    out = BatchNormalization()(out)
    out = Activation('relu')(out)

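    # first_block and repeated_block are residual-block helpers defined
    # elsewhere in the project (not shown in this listing).
    out = first_block(out, (64, 128), dropout=dropout)

    for _ in range(loop_depth):
        out = repeated_block(out, (64, 128), dropout=dropout)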

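    # Flatten the (time, channels) feature map into a single vector
    out = Flatten()(out)

    out = BatchNormalization()(out)
    out = Activation('relu')(out)
    # Linear output layer (regression, trained with MSE below)
    out = Dense(output_dim)(out)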

    model = Model(inputs=[a2_inp, d2_inp, d1_inp], outputs=[out])

    model.compile(loss='mse', optimizer='adam', metrics=['mse', 'mae'])
    return model
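The listing does not show how the three inputs are produced. Their names follow the usual notation for a two-level discrete wavelet transform: a2 is the level-2 approximation, d2 the level-2 detail and d1 the level-1 detail. The sketch below assumes a Haar wavelet applied to each channel along the time axis, uses only NumPy, and assumes the signal length is divisible by 4; the project may use a different wavelet or a library such as PyWavelets. It also computes the sequence length after the `axis=1` concatenation and the first `Conv1D(128, 5)`, which in Keras defaults to `padding='valid'` and `strides=1`. The helper names `haar_step`, `haar_a2_d2_d1` and `conv1d_valid_length`, and the 1024-sample, 3-channel example signal, are hypothetical.

```python
import numpy as np


def haar_step(x):
    """One level of an orthonormal Haar DWT along axis 0 (time).

    Returns (approximation, detail), each with half as many time steps.
    Assumes x.shape[0] is even.
    """
    even, odd = x[0::2], x[1::2]
    approx = (even + odd) / np.sqrt(2.0)
    detail = (even - odd) / np.sqrt(2.0)
    return approx, detail


def haar_a2_d2_d1(signal):
    """Two-level Haar decomposition of a (time_step, input_dim) array.

    Returns (a2, d2, d1), in the same order as the model's inputs list.
    """
    a1, d1 = haar_step(signal)
    a2, d2 = haar_step(a1)
    return a2, d2, d1


def conv1d_valid_length(length, kernel_size):
    # Output length of a stride-1 convolution with 'valid' padding
    return length - kernel_size + 1


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Hypothetical signal: 1024 samples from 3 sensor channels
    signal = rng.standard_normal((1024, 3))
    a2, d2, d1 = haar_a2_d2_d1(signal)
    total = a2.shape[0] + d2.shape[0] + d1.shape[0]
    print(a2.shape, d2.shape, d1.shape)  # (256, 3) (256, 3) (512, 3)
    print(total, conv1d_valid_length(total, 5))  # 1024 1020
```

With the Haar wavelet each level halves the length, so the three time steps add back up to the original signal length. Wavelets with longer filters typically produce slightly different coefficient lengths because of boundary handling, so a2_time_step, d2_time_step and d1_time_step would need to be set from the actual arrays.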