bnn.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:dac-training 作者: jlonij 项目源码 文件源码
def _feature_branch(n_features):
    '''
    Build one input branch: Input -> Dense (relu, max-norm 3) -> Dropout.

    The hidden layer has as many units as the branch has input features.
    Returns a tuple (input tensor, branch output tensor).
    '''
    inputs = Input(shape=(n_features,))
    x = Dense(n_features, activation='relu',
            kernel_constraint=maxnorm(3))(inputs)
    x = Dropout(0.25)(x)
    return inputs, x


def create_model(data):
    '''
    Build and compile the three-branch keras model.

    data is indexed at positions 0, 1 and 2 (entity, candidate and match
    features respectively); only the second entry of each item's .shape
    is read, to size that branch's input and hidden layer.

    The three branches are concatenated and passed through Dense layers
    of 32, 16 and 8 units (each followed by dropout) into a single
    sigmoid output. The model is compiled with the RMSprop optimizer,
    binary cross-entropy loss and the accuracy metric.

    Returns the compiled model, whose inputs are ordered
    [entity, candidate, match].

    NOTE: the match branch's hidden layer used to be sized with
    data[1].shape[1] (the candidate feature count), unlike the other two
    branches; it now uses data[2].shape[1]. Weights saved from a model
    built with the old width will not load when the candidate and match
    feature counts differ.
    '''
    # One branch per feature group, each sized by its own feature count.
    entity_inputs, entity_x = _feature_branch(data[0].shape[1])
    candidate_inputs, candidate_x = _feature_branch(data[1].shape[1])
    match_inputs, match_x = _feature_branch(data[2].shape[1])

    # Merge the branches and narrow down towards the output.
    x = concatenate([entity_x, candidate_x, match_x])
    for units in (32, 16, 8):
        x = Dense(units, activation='relu', kernel_constraint=maxnorm(3))(x)
        x = Dropout(0.25)(x)

    predictions = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=[entity_inputs, candidate_inputs, match_inputs],
        outputs=predictions)
    model.compile(optimizer='RMSprop', loss='binary_crossentropy',
        metrics=['accuracy'])

    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号