def train(self, S_ind, C_ind, use_onto_lstm=True, use_attention=True,
          num_epochs=20, hierarchical=False, base=2):
    """Build, compile and fit a next-word prediction model.

    The input at every position is the current word (or its concept
    indices), and the target is the word index at the following position.
    The fitted model is stored on ``self.model``. Compilation and training
    times are reported on stderr.

    Args:
        S_ind: Array of word indices, sliced as ``S_ind[:, 1:]`` for the
            targets and as ``S_ind[:, :-1]`` for the inputs when
            ``use_onto_lstm`` is false.
            NOTE(review): presumably shaped (num_sentences, length) - confirm.
        C_ind: Array of concept (synset) indices, sliced as ``C_ind[:, :-1]``
            for the inputs when ``use_onto_lstm`` is true.
            NOTE(review): presumably carries extra sense/hypernym axes after
            the length axis - confirm against the caller.
        use_onto_lstm: If true, embed ``C_ind`` and encode with
            ``OntoAttentionLSTM``; otherwise embed ``S_ind`` and encode with
            a plain ``LSTM``.
        use_attention: Forwarded to ``OntoAttentionLSTM``; ignored when
            ``use_onto_lstm`` is false.
        num_epochs: Upper bound on training epochs (early stopping may end
            training sooner).
        hierarchical: If true, targets come from
            ``self._factor_target_indices`` and the model gets one softmax
            output per factored target; otherwise a single one-hot target
            is used.
        base: Forwarded to ``self._factor_target_indices``.

    Returns:
        The result of ``get_weights()`` on the embedding layer.
    """
    # Predict next word from current synsets.
    # Inputs drop the last position of every sentence ...
    X = C_ind[:, :-1] if use_onto_lstm else S_ind[:, :-1]
    # ... and targets drop the first, so position t predicts word t + 1.
    Y_inds = S_ind[:, 1:]

    if hierarchical:
        train_targets = self._factor_target_indices(Y_inds, base=base)
    else:
        # The one-hot width is derived from the largest index present in
        # the training targets.
        train_targets = [self._make_one_hot(Y_inds, Y_inds.max() + 1)]

    length = Y_inds.shape[1]
    lstm_outdim = self.word_dim
    # The embedding vocabulary depends on what is being embedded.
    if use_onto_lstm:
        embed_input_dim = len(self.dp.synset_index)
    else:
        embed_input_dim = len(self.dp.word_index)

    # Named ``input_layer`` so the ``input`` builtin is not shadowed.
    input_layer = Input(shape=X.shape[1:], dtype='int32')
    embed_layer = HigherOrderEmbedding(
        name='embedding', input_dim=embed_input_dim,
        output_dim=self.word_dim, input_shape=X.shape[1:], mask_zero=True)
    sent_rep = embed_layer(input_layer)
    reg_sent_rep = Dropout(0.5)(sent_rep)

    if use_onto_lstm:
        lstm_out = OntoAttentionLSTM(
            name='sent_lstm', input_dim=self.word_dim,
            output_dim=lstm_outdim, input_length=length,
            num_senses=self.num_senses, num_hyps=self.num_hyps,
            return_sequences=True,
            use_attention=use_attention)(reg_sent_rep)
    else:
        lstm_out = LSTM(
            name='sent_lstm', input_dim=self.word_dim,
            output_dim=lstm_outdim, input_length=length,
            return_sequences=True)(reg_sent_rep)

    # Make one softmax node for each (possibly factored) target.
    output_nodes = [
        TimeDistributed(Dense(input_dim=lstm_outdim,
                              output_dim=target.shape[-1],
                              activation='softmax'))(lstm_out)
        for target in train_targets
    ]

    model = Model(input=input_layer, output=output_nodes)
    # summary() prints the table itself and returns None, so wrapping it in
    # a print would only emit a stray "None".
    model.summary()

    early_stopping = EarlyStopping()
    precompile_time = time.time()
    model.compile(loss='categorical_crossentropy', optimizer='adam')
    postcompile_time = time.time()
    sys.stderr.write("Model compilation took %d s\n"
                     % (postcompile_time - precompile_time))

    model.fit(X, train_targets, nb_epoch=num_epochs, validation_split=0.1,
              callbacks=[early_stopping])
    posttrain_time = time.time()
    sys.stderr.write("Training took %d s\n"
                     % (posttrain_time - postcompile_time))

    # Read the weights from the layer object directly instead of relying on
    # its position in model.layers.
    concept_reps = embed_layer.get_weights()
    self.model = model
    return concept_reps
评论列表
文章目录