cnn_tf_profile.py source code

python

Project: DeepTFAS-in-D.mel    Author: mu102449
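import keras
from keras.layers import Activation, Dense, Flatten, Input
from keras.models import Model

# X_DATA, TISSUE_DATA, CLASSES and ARGS are module-level globals not shown in
# this excerpt (presumably the sequence data, the tissue features, the number
# of output classes and the parsed command-line arguments).


def dnn_model():
    '''
    Construct the DNN model:
    Flatten + 2*(Dense + relu) + Dense + softmax
    If ARGS.T is set, the tissue inputs are concatenated after the first
    Dense/relu layer.
    '''
    print('Construct DNN model')
    # Sequence branch: flatten each sample into a single feature vector
    main_inputs = Input(shape=X_DATA[0].shape, name='sequence_inputs')
    hidden = Flatten()(main_inputs)
    hidden = Dense(128)(hidden)
    hidden = Activation('relu')(hidden)
    if ARGS.T:
        # Optional auxiliary branch: append the tissue features to the hidden vector
        auxiliary_inputs = Input(shape=TISSUE_DATA[0].shape, name='tissue_inputs')
        hidden = keras.layers.concatenate([hidden, auxiliary_inputs])
    hidden = Dense(128)(hidden)
    hidden = Activation('relu')(hidden)
    # Output layer: one probability per class
    outputs = Dense(CLASSES, activation='softmax')(hidden)
    if ARGS.T:
        model = Model(inputs=[main_inputs, auxiliary_inputs], outputs=outputs)
    else:
        model = Model(inputs=main_inputs, outputs=outputs)
    model.summary()

    # categorical_crossentropy expects one-hot encoded labels
    model.compile(
        loss=keras.losses.categorical_crossentropy,
        optimizer=keras.optimizers.Adam(),
        metrics=['accuracy'])
    # Save the initial (untrained) weights, using ARGS.o as the path prefix.
    # Written for Keras 2: Keras 3 requires the filename to end in '.weights.h5'.
    model.save_weights('{}model.h5~'.format(ARGS.o))
    return model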
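As a rough check on what model.summary() reports, the parameter count of this architecture can be worked out by hand. The sketch below is plain Python (no Keras needed); the helper names dense_params and dnn_param_count, and the example sequence shape (100, 4) with 2 classes and 10 tissue features, are hypothetical and chosen only for illustration. Flatten, Activation and the concatenation add no parameters, so only the three Dense layers contribute.

```python
from math import prod


def dense_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected (Dense) layer."""
    return n_in * n_out + n_out


def dnn_param_count(seq_shape, n_classes, tissue_dim=0, units=128):
    """Parameters of Flatten + Dense/relu + [concat] + Dense/relu + Dense/softmax.

    tissue_dim=0 models the case where the tissue branch (ARGS.T) is off.
    """
    flat = prod(seq_shape)                              # Flatten: no parameters
    first = dense_params(flat, units)                   # first Dense(128)
    second = dense_params(units + tissue_dim, units)    # concat widens its input
    out = dense_params(units, n_classes)                # softmax classifier
    return first + second + out


# Hypothetical shapes, for illustration only
print(dnn_param_count((100, 4), n_classes=2))                 # 68098
print(dnn_param_count((100, 4), n_classes=2, tissue_dim=10))  # 69378
```

With these assumed shapes, switching the tissue branch on adds 10 × 128 = 1280 weights to the second Dense layer, because the concatenation widens its input from 128 to 138 features; every other layer is unchanged.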