model_sentences.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:onto-lstm 作者: pdasigi 项目源码 文件源码
def train(self, S_ind, C_ind, use_onto_lstm=True, use_attention=True, num_epochs=20,  hierarchical=False, base=2):
    """Build, compile and fit a next-word prediction model.

    The network is: embedding -> dropout -> (OntoAttentionLSTM | LSTM) ->
    one time-distributed softmax per target. The fitted model is stored on
    ``self.model``.

    Args:
        S_ind: Indexed sentences. Sliced as ``S_ind[:, 1:]`` for targets and
            ``S_ind[:, :-1]`` for inputs, so it must be at least 2-D
            (presumably a numpy array of word indices -- confirm with caller).
        C_ind: Indexed synset (concept) input, used as the model input when
            ``use_onto_lstm`` is true. Sliced as ``C_ind[:, :-1]``.
        use_onto_lstm: If true, feed ``C_ind`` through an
            ``OntoAttentionLSTM`` and size the embedding by the synset
            index; otherwise feed ``S_ind`` through a plain ``LSTM`` and
            size the embedding by the word index.
        use_attention: Forwarded to ``OntoAttentionLSTM``; ignored when
            ``use_onto_lstm`` is false.
        num_epochs: Passed to ``fit`` as ``nb_epoch``. An ``EarlyStopping``
            callback is also installed, so training may end sooner.
        hierarchical: If true, targets come from
            ``self._factor_target_indices``; otherwise a single one-hot
            target with ``Y_inds.max() + 1`` classes is used.
        base: Forwarded to ``self._factor_target_indices``; unused when
            ``hierarchical`` is false.

    Returns:
        The value of ``get_weights()`` on the embedding layer after training.
    """
    # Predict next word from current synsets
    # Inputs: drop the last position of every sentence.
    X = C_ind[:, :-1] if use_onto_lstm else S_ind[:, :-1]
    # Targets: drop the first position, so target t is the word at t + 1.
    Y_inds = S_ind[:, 1:]
    if hierarchical:
        train_targets = self._factor_target_indices(Y_inds, base=base)
    else:
        train_targets = [self._make_one_hot(Y_inds, Y_inds.max() + 1)]
    length = Y_inds.shape[1]
    lstm_outdim = self.word_dim

    num_words = len(self.dp.word_index)
    num_syns = len(self.dp.synset_index)
    # Named model_input rather than `input` to avoid shadowing the builtin.
    model_input = Input(shape=X.shape[1:], dtype='int32')
    embed_input_dim = num_syns if use_onto_lstm else num_words
    embed_layer = HigherOrderEmbedding(name='embedding', input_dim=embed_input_dim, output_dim=self.word_dim, input_shape=X.shape[1:], mask_zero=True)
    sent_rep = embed_layer(model_input)
    reg_sent_rep = Dropout(0.5)(sent_rep)
    if use_onto_lstm:
        lstm_out = OntoAttentionLSTM(name='sent_lstm', input_dim=self.word_dim, output_dim=lstm_outdim, input_length=length, num_senses=self.num_senses, num_hyps=self.num_hyps, return_sequences=True, use_attention=use_attention)(reg_sent_rep)
    else:
        lstm_out = LSTM(name='sent_lstm', input_dim=self.word_dim, output_dim=lstm_outdim, input_length=length, return_sequences=True)(reg_sent_rep)

    # Make one node for each factored target
    output_nodes = []
    for target in train_targets:
        node = TimeDistributed(Dense(input_dim=lstm_outdim, output_dim=target.shape[-1], activation='softmax'))(lstm_out)
        output_nodes.append(node)

    model = Model(input=model_input, output=output_nodes)
    # summary() prints the table itself and returns None, so wrapping it in a
    # print would only emit the string "None".
    model.summary()
    early_stopping = EarlyStopping()
    precompile_time = time.time()
    model.compile(loss='categorical_crossentropy', optimizer='adam')
    postcompile_time = time.time()
    # sys.stderr.write emits the same text under both Python 2 and Python 3.
    sys.stderr.write("Model compilation took %d s\n" % (postcompile_time - precompile_time))
    model.fit(X, train_targets, nb_epoch=num_epochs, validation_split=0.1, callbacks=[early_stopping])
    posttrain_time = time.time()
    sys.stderr.write("Training took %d s\n" % (posttrain_time - postcompile_time))
    # Read the weights through the layer handle instead of a positional
    # lookup in model.layers, which depends on layer ordering.
    concept_reps = embed_layer.get_weights()
    self.model = model
    return concept_reps
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号