# Keras imports needed by this function (tensorflow.keras works equivalently).
from keras.layers import Input, Conv1D, BatchNormalization, Activation, Flatten, Dense, concatenate
from keras.models import Model


def build_multi_input_main_residual_network(batch_size,
                                             a2_time_step,
                                             d2_time_step,
                                             d1_time_step,
                                             input_dim,
                                             output_dim,
                                             loop_depth=15,
                                             dropout=0.3):
    '''
    A multi-input residual network for wavelet-transformed sequences.
    :param batch_size: batch size (kept for API compatibility; not used when building the graph)
    :param a2_time_step: number of time steps of the a2 (approximation) coefficients
    :param d2_time_step: number of time steps of the d2 (detail) coefficients
    :param d1_time_step: number of time steps of the d1 (detail) coefficients
    :param input_dim: feature dimension of each input
    :param output_dim: dimension of the network output
    :param loop_depth: number of repeated residual blocks
    :param dropout: dropout rate
    :return: a compiled Keras Model
    '''
    a2_inp = Input(shape=(a2_time_step, input_dim), name='a2')
    d2_inp = Input(shape=(d2_time_step, input_dim), name='d2')
    d1_inp = Input(shape=(d1_time_step, input_dim), name='d1')

    # Concatenate the three wavelet branches along the time axis.
    out = concatenate([a2_inp, d2_inp, d1_inp], axis=1)
    out = Conv1D(128, 5)(out)
    out = BatchNormalization()(out)
    out = Activation('relu')(out)

    # First residual block, followed by loop_depth repeated blocks.
    out = first_block(out, (64, 128), dropout=dropout)
    for _ in range(loop_depth):
        out = repeated_block(out, (64, 128), dropout=dropout)

    # Flatten and map to the output dimension.
    out = Flatten()(out)
    out = BatchNormalization()(out)
    out = Activation('relu')(out)
    out = Dense(output_dim)(out)

    model = Model(inputs=[a2_inp, d2_inp, d1_inp], outputs=[out])
    model.compile(loss='mse', optimizer='adam', metrics=['mse', 'mae'])
    return model
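
# A minimal usage sketch. The shape values below are illustrative only, and it
# assumes first_block / repeated_block are the residual-block helpers defined
# elsewhere in this post; the three inputs are fed by name.
import numpy as np

model = build_multi_input_main_residual_network(batch_size=32,
                                                a2_time_step=25,
                                                d2_time_step=25,
                                                d1_time_step=50,
                                                input_dim=1,
                                                output_dim=1,
                                                loop_depth=15,
                                                dropout=0.3)

a2 = np.random.rand(32, 25, 1)
d2 = np.random.rand(32, 25, 1)
d1 = np.random.rand(32, 50, 1)
y = np.random.rand(32, 1)
model.fit({'a2': a2, 'd2': d2, 'd1': d1}, y, epochs=1, batch_size=32)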