CorrMCNN_Arch2.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:DeepLearn 作者: GauravBh1010tt 项目源码 文件源码
def _conv_tower(inp, deep2_name, out_name, img_shape=(28, 14, 1)):
    """Build one view's encoder: three conv/pool stages, then three dense layers.

    Both views in ``buildModel`` use the same architecture and differ only in
    the names of the last two dense layers. Layers are created in the same
    order as the original inline code (conv/pool x3, flatten, dense x3).

    Args:
        inp: Keras input tensor for one view. Its size must equal the product
            of ``img_shape`` for the ``Reshape`` to succeed.
        deep2_name: layer name for the ``hdim_deep2``-unit dense layer.
        out_name: layer name for the final ``hdim``-unit dense layer.
        img_shape: shape the flat input is reshaped to before the convolutions.

    Returns:
        Tuple ``(h_deep, h_deep2, h_out)`` holding the outputs of the
        ``hdim_deep``, ``hdim_deep2`` and ``hdim`` sigmoid dense layers.
    """
    t = Reshape(img_shape)(inp)
    # 128 -> 64 -> 49 filters, each conv followed by a 2x2 'same' max-pool.
    for n_filters in (128, 64, 49):
        t = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(t)
        t = MaxPooling2D((2, 2), padding='same')(t)
    t = Flatten()(t)
    h_deep = Dense(hdim_deep, activation='sigmoid')(t)
    h_deep2 = Dense(hdim_deep2, activation='sigmoid', name=deep2_name)(h_deep)
    h_out = Dense(hdim, activation='sigmoid', name=out_name)(h_deep2)
    return h_deep, h_deep2, h_out


def buildModel(loss_type, lamda, lamda2=0.001, lamda3=0.05, img_shape=(28, 14, 1)):
    """Build and compile the two-view correlational CNN autoencoder.

    Each view (x, y) is encoded by its own conv tower. The two ``hdim``-unit
    encodings are summed into a joint representation ``h``, and both views are
    reconstructed from ``h`` by linear dense layers. A shared ``branchModel``
    is then applied three times:

    * y passed through ``ZeroPadding``, giving reconstructions ``recx1``/``recy1``
      and representation ``h1``;
    * x passed through ``ZeroPadding``, giving ``recx2``/``recy2`` and ``h2``;
    * both views unchanged, used only for the intermediate tower activations.

    Relies on module-level names defined elsewhere in this file: ``dimx``,
    ``dimy``, ``hdim_deep``, ``hdim_deep2``, ``hdim``, ``ZeroPadding``,
    ``CorrnetCost`` and ``corr_loss``.

    Args:
        loss_type: not used by this function; kept for caller compatibility.
        lamda: weight for the correlation term between ``h1`` and ``h2``.
            It is passed negated to ``CorrnetCost``.
        lamda2: weight (negated) for the correlation between the two towers'
            ``hdim_deep`` layers. The default matches the former hard-coded value.
        lamda3: weight (negated) for the correlation between the two towers'
            ``hdim_deep2`` layers. The default matches the former hard-coded value.
        img_shape: shape each flat view is reshaped to. ``dimx`` and ``dimy``
            must both equal its element count (392 for the default).

    Returns:
        ``(model, branchModel)``.

        ``model`` maps ``[x, y]`` to
        ``[recy1, recx2, recx1, recy2, corr1, corr2, corr3]``. It is compiled
        with RMSprop, using MSE on the four reconstructions and ``corr_loss``
        on the three correlation outputs.

        ``branchModel`` maps ``[x, y]`` to
        ``[recx, recy, h, hx1, hy1, hx2, hy2]``.
    """
    inpx = Input(shape=(dimx,))
    inpy = Input(shape=(dimy,))

    # Towers are built x first, then y, matching the original layer-creation order.
    hx1, hx2, hx = _conv_tower(inpx, 'hid_l1', 'hid_l', img_shape)
    hy1, hy2, hy = _conv_tower(inpy, 'hid_r1', 'hid_r', img_shape)

    # Joint representation: element-wise sum of the two view encodings.
    # NOTE(review): `Merge(mode="sum")` is the Keras 1 API and was removed in
    # Keras 2.2; `Add()([hx, hy])` is the equivalent there. It is left unchanged
    # because this file's Keras import path is not visible from here.
    h = Merge(mode="sum")([hx, hy])

    recx = Dense(dimx)(h)
    recy = Dense(dimy)(h)

    branchModel = Model([inpx, inpy], [recx, recy, h, hx1, hy1, hx2, hy2])

    # ZeroPadding is defined elsewhere; presumably it zeroes its input to
    # simulate a missing view. TODO confirm.
    recx1, recy1, h1, _, _, _, _ = branchModel([inpx, ZeroPadding()(inpy)])
    recx2, recy2, h2, _, _, _, _ = branchModel([ZeroPadding()(inpx), inpy])

    # Both views present: only the intermediate tower activations are used.
    # (A reconstruction loss on this pass could also be added.)
    _, _, _, hx_1, hy_1, hx_2, hy_2 = branchModel([inpx, inpy])

    corr1 = CorrnetCost(-lamda)([h1, h2])
    corr2 = CorrnetCost(-lamda2)([hx_1, hy_1])
    corr3 = CorrnetCost(-lamda3)([hx_2, hy_2])

    model = Model([inpx, inpy], [recy1, recx2, recx1, recy2, corr1, corr2, corr3])
    model.compile(
        loss=["mse", "mse", "mse", "mse", corr_loss, corr_loss, corr_loss],
        optimizer="rmsprop",
    )

    return model, branchModel
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号