def _embed_sentences(self):
    """Embed sentences via the last output cell of an LSTM.

    Looks up word embeddings for the token ids in ``self.input``. Runs a
    single ``BasicLSTMCell`` with ``self.d`` units over them using
    ``tf.nn.dynamic_rnn``, bounded per example by ``self.input_lengths``.
    The per-step outputs are then reduced with ``get_rnn_output``.

    Side effect: calls ``tf.set_random_seed``, which sets the graph-level
    seed for the whole default graph, not just the "LSTM" variable scope.

    Returns:
        A 2-tuple ``(sentence_repr, {})``. ``sentence_repr`` is whatever
        ``get_rnn_output(rnn_out, self.d, self.input_lengths)`` returns; the
        second element is always an empty dict.
        NOTE(review): the dict's purpose is defined by the caller
        (presumably extra feed-dict entries) - confirm.
    """
    word_embeddings = self._get_embedding()
    # Assumes self.input holds integer token ids laid out batch-major
    # (dim 0 is used as the batch size just below) - TODO confirm the full shape.
    word_feats = tf.nn.embedding_lookup(word_embeddings, self.input)
    batch_size = tf.shape(self.input)[0]
    with tf.variable_scope("LSTM"):
        # Seed the graph before the LSTM's variables are created inside
        # dynamic_rnn. The "- 1" offset's purpose is not visible here.
        tf.set_random_seed(self.seed - 1)
        cell = tf.contrib.rnn.BasicLSTMCell(self.d)
        initial_state = cell.zero_state(batch_size, tf.float32)
        # time_major=False: word_feats and rnn_out are (batch, time, ...).
        # sequence_length stops the recurrence at each example's true length.
        rnn_out, _ = tf.nn.dynamic_rnn(
            cell,
            word_feats,
            sequence_length=self.input_lengths,
            initial_state=initial_state,
            time_major=False,
        )
    # Get potentials. Presumably get_rnn_output picks each sequence's output
    # at its last valid step (matching the docstring) - verify in its definition.
    return get_rnn_output(rnn_out, self.d, self.input_lengths), {}
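# "评论列表" ("Comment list") - web-page residue from extraction, not code.
# "文章目录" ("Table of contents") - web-page residue from extraction, not code.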