def buildModel(loss_type, lamda, optimizer="rmsprop"):
    """Build a two-view autoencoder and the compiled model used to train it.

    Each view (``x`` of width ``dimx``, ``y`` of width ``dimy``) is encoded by
    its own stack of three sigmoid ``Dense`` layers.  The two codes are summed
    into a shared hidden representation ``h``, from which both views are
    reconstructed by linear ``Dense`` layers.  This shared network
    (``branchModel``) is then applied three times with the same weights:

    * ``y`` passed through ``ZeroPadding``;
    * ``x`` passed through ``ZeroPadding``;
    * both views unchanged.

    Relies on the module-level globals ``dimx``, ``dimy``, ``hdim_deep``,
    ``hdim_deep2`` and ``hdim``, and on the project helpers ``ZeroPadding``,
    ``CorrnetCost`` and ``corr_loss``.

    Args:
        loss_type: Selects the outputs (and per-output losses) of the
            training model:

            * 1 -- six reconstructions (MSE) plus the correlation term;
            * 2 -- the four reconstructions from the ZeroPadding'd inputs
              (MSE) plus the correlation term;
            * 3 -- six reconstructions (MSE), no correlation term;
            * 4 -- the four reconstructions from the ZeroPadding'd inputs
              (MSE), no correlation term.
        lamda: Weight of the correlation term.  It is passed to
            ``CorrnetCost`` negated (``-lamda``).
        optimizer: Optimizer handed to ``Model.compile``.  Defaults to
            ``"rmsprop"``, the value this function always used.

    Returns:
        ``(model, branchModel)``: the compiled training model and the shared
        encoder/decoder ``Model`` (not compiled here), whose outputs are
        ``[recx, recy, h]``.

    Raises:
        ValueError: if ``loss_type`` is not 1, 2, 3 or 4.  Without this check
            no ``model`` would be built and the function would fail with an
            ``UnboundLocalError``.
    """
    # Validate before building any layers so a bad argument fails fast.
    if loss_type not in (1, 2, 3, 4):
        raise ValueError(
            "loss_type must be 1, 2, 3 or 4, got %r" % (loss_type,))

    inpx = Input(shape=(dimx,))
    inpy = Input(shape=(dimy,))

    # View-specific encoders: three sigmoid layers per view.
    hx = Dense(hdim_deep, activation='sigmoid')(inpx)
    hx = Dense(hdim_deep2, activation='sigmoid', name='hid_l1')(hx)
    hx = Dense(hdim, activation='sigmoid', name='hid_l')(hx)
    hy = Dense(hdim_deep, activation='sigmoid')(inpy)
    hy = Dense(hdim_deep2, activation='sigmoid', name='hid_r1')(hy)
    hy = Dense(hdim, activation='sigmoid', name='hid_r')(hy)

    # Shared code: element-wise sum of the two view codes.  No activation is
    # applied after the sum.
    h = Merge(mode="sum")([hx, hy])

    # Linear decoders reconstruct both views from the shared code.
    recx = Dense(dimx)(h)
    recy = Dense(dimy)(h)
    branchModel = Model([inpx, inpy], [recx, recy, h])

    # Reuse branchModel (same weights) on three input combinations.
    # NOTE(review): ZeroPadding is defined elsewhere; presumably it replaces
    # its input with zeros so the first two calls see only one view -- confirm.
    [recx1, recy1, h1] = branchModel([inpx, ZeroPadding()(inpy)])
    [recx2, recy2, h2] = branchModel([ZeroPadding()(inpx), inpy])
    # Both views present.  Named h3 so it does not shadow the merged tensor
    # `h` that branchModel was built from.
    [recx3, recy3, h3] = branchModel([inpx, inpy])

    # Correlation term computed from the two single-view codes, scaled by
    # -lamda.  (CorrnetCost / corr_loss are defined elsewhere.)
    corr = CorrnetCost(-lamda)([h1, h2])

    # The output order is part of the interface: callers supply their
    # training targets in exactly this order.
    if loss_type in (1, 3):
        outputs = [recy1, recx2, recx3, recx1, recy2, recy3]
    else:  # loss_type in (2, 4): drop the both-views reconstructions.
        outputs = [recy1, recx2, recx1, recy2]
    losses = ["mse"] * len(outputs)
    if loss_type in (1, 2):
        outputs.append(corr)
        losses.append(corr_loss)

    model = Model([inpx, inpy], outputs)
    model.compile(loss=losses, optimizer=optimizer)
    return model, branchModel
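# NOTE(review): the following were web-page residue (blog UI labels), not code:
#   评论列表 -- "Comment list"
#   文章目录 -- "Table of contents"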