def trainCNN(obj, dataset_headLines, dataset_body, wordVec_model=None):
    """Build and compile a two-tower CNN classifier over (headline, body) pairs.

    ``obj.process_data`` first turns the raw texts into model inputs and an
    embedding matrix. Each input is then embedded, passed through its own
    1-D convolution (valid padding, ReLU) and global max pooling. The two
    pooled vectors are combined into an element-wise product and an ``Abs``
    feature, and these are concatenated. A small dense classifier with a
    4-way softmax follows.

    The model is compiled but NOT fitted here; training is left to the caller.

    Args:
        obj: Preprocessing object exposing ``process_data(sent_Q=..., sent_A=...,
            dimx=..., dimy=..., wordVec_model=...)``. It must return
            ``(train_head, train_body, embedding_matrix)``.
        dataset_headLines: Headline texts, forwarded as ``sent_Q``.
        dataset_body: Body texts, forwarded as ``sent_A``.
        wordVec_model: Word-vector model forwarded to ``process_data``.
            When omitted, the module-level ``wordVec_model`` global is used,
            which is what this function always did before.

    Returns:
        A tuple ``(model, train_head, train_body)``. ``model`` is the compiled
        Keras model. The other two items are the processed headline and body
        data exactly as returned by ``process_data``.
    """
    # ``Merge(mode="concat")`` was removed in Keras 2.2. ``Concatenate`` is its
    # replacement and exists since Keras 2.0, which ``Multiply`` already needs.
    from keras.layers import Concatenate

    if wordVec_model is None:
        # Backward compatibility: the original read this module-level name.
        wordVec_model = globals()["wordVec_model"]

    dimx = 100          # headline length passed to process_data and used for inpx
    dimy = 200          # body length passed to process_data and used for inpy
    nb_filter = 100     # convolution filters per tower
    filter_length = 4   # convolution window width
    dense_neuron = 16   # width of the hidden 'wrap' layer
    n_classes = 4       # number of softmax outputs

    train_head, train_body, embedding_matrix = obj.process_data(
        sent_Q=dataset_headLines,
        sent_A=dataset_body,
        dimx=dimx,
        dimy=dimy,
        wordVec_model=wordVec_model,
    )

    # Each tower calls word2vec_embedding_layer separately, but both calls use
    # the same embedding_matrix. Whether the resulting weights are tied depends
    # on that helper, which is not visible here.
    inpx = Input(shape=(dimx,), dtype='int32', name='inpx')
    x = word2vec_embedding_layer(embedding_matrix)(inpx)
    inpy = Input(shape=(dimy,), dtype='int32', name='inpy')
    y = word2vec_embedding_layer(embedding_matrix)(inpy)

    # Separate (unshared) convolutions for the headline and body towers.
    ques = Convolution1D(filters=nb_filter, kernel_size=filter_length,
                         padding='valid', activation='relu', strides=1)(x)
    ans = Convolution1D(filters=nb_filter, kernel_size=filter_length,
                        padding='valid', activation='relu', strides=1)(y)

    hx = GlobalMaxPooling1D()(ques)
    hy = GlobalMaxPooling1D()(ans)

    # Pairwise interaction features between the two pooled vectors.
    h1 = Multiply()([hx, hy])
    # NOTE(review): Abs is a custom layer defined elsewhere. It presumably
    # computes |hx - hy|; confirm against its definition.
    h2 = Abs()([hx, hy])
    h = Concatenate(name='h')([h1, h2])

    wrap = Dense(dense_neuron, activation='relu', name='wrap')(h)
    score = Dense(n_classes, activation='softmax', name='score')(wrap)

    model = Model([inpx, inpy], score)
    model.compile(loss='categorical_crossentropy', optimizer="adadelta",
                  metrics=['accuracy'])
    return model, train_head, train_body
评论列表
文章目录