def _dense_branch(n_features):
    """Build one input branch: Input -> Dense(relu, max-norm 3) -> Dropout(0.25).

    The Dense layer has as many units as the branch has input features.

    Returns:
        ``(inputs, x)`` -- the Input tensor (needed when constructing the Model)
        and the branch output tensor (needed for the merge).
    """
    inputs = Input(shape=(n_features,))
    x = Dense(n_features, activation='relu',
              kernel_constraint=maxnorm(3))(inputs)
    x = Dropout(0.25)(x)
    return inputs, x


def create_model(data):
    """Build and compile a three-input binary classifier.

    ``data[0]`` (entity), ``data[1]`` (candidate) and ``data[2]`` (match)
    each feed their own Dense/Dropout branch. Each branch is sized to that
    input's own feature count. The three branch outputs are concatenated and
    passed through a 32-16-8 ReLU stack (each layer followed by Dropout 0.25)
    ending in a single sigmoid unit.

    Args:
        data: indexable holding at least three arrays. Only
            ``data[i].shape[1]`` is read, so each is presumably 2-D
            (samples, features) -- confirm against the caller.

    Returns:
        A compiled Keras ``Model`` taking
        ``[entity_inputs, candidate_inputs, match_inputs]``. It is compiled
        with the RMSprop optimizer, binary cross-entropy loss and the
        accuracy metric.
    """
    entity_inputs, entity_x = _dense_branch(data[0].shape[1])
    candidate_inputs, candidate_x = _dense_branch(data[1].shape[1])
    # The match branch is sized from its own input (data[2]). It previously
    # used data[1]'s width, a copy-paste slip from the candidate branch.
    match_inputs, match_x = _dense_branch(data[2].shape[1])

    # Merge the branches, then shrink through progressively narrower layers.
    x = concatenate([entity_x, candidate_x, match_x])
    for units in (32, 16, 8):
        x = Dense(units, activation='relu', kernel_constraint=maxnorm(3))(x)
        x = Dropout(0.25)(x)
    predictions = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=[entity_inputs, candidate_inputs, match_inputs],
                  outputs=predictions)
    model.compile(optimizer='RMSprop', loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model
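# (Leftover web-page navigation labels "Comment list" / "Table of contents"
# from the page this code was copied from -- not part of the program.)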