def generate_rnn_output(self):
    """Generate RNN state outputs with word embeddings as inputs.

    Everything is created inside the ``"generate_seq_output"`` variable
    scope. Each element of ``self.encoder_inputs`` is looked up in an
    ``"embedding"`` variable of shape
    ``[source_vocab_size, word_embedding_size]`` and the resulting list is
    fed to ``static_bidirectional_rnn`` (when ``self.bidirectional_rnn`` is
    truthy) or to ``static_rnn`` with ``self.cell_fw`` only.

    Returns:
        A 3-tuple ``(encoder_outputs, encoder_state, attention_states)``:

        * ``encoder_outputs`` -- the outputs returned by the RNN call, one
          entry per element of ``self.encoder_inputs``.
        * ``encoder_state`` -- the last element of the final state,
          concatenated along axis 1 (forward and backward parts are
          concatenated together in the bidirectional case).
        * ``attention_states`` -- every output reshaped to
          ``[-1, 1, output_size]`` and concatenated along axis 1.
          NOTE(review): presumably ``[batch, time, output_size]`` --
          confirm against the caller.
    """
    with tf.variable_scope("generate_seq_output"):
        # The embedding table and the lookups are identical for both RNN
        # variants, so they are built once before branching.
        embedding = tf.get_variable(
            "embedding",
            [self.source_vocab_size, self.word_embedding_size])
        encoder_emb_inputs = [
            tf.nn.embedding_lookup(embedding, encoder_input)
            for encoder_input in self.encoder_inputs
        ]

        if self.bidirectional_rnn:
            encoder_outputs, encoder_state_fw, encoder_state_bw = (
                static_bidirectional_rnn(
                    self.cell_fw,
                    self.cell_bw,
                    encoder_emb_inputs,
                    sequence_length=self.sequence_length,
                    dtype=tf.float32))
            # with state_is_tuple = True, if num_layers > 1,
            # here we simply use the state from last layer as the encoder
            # state.
            # NOTE(review): the [-1] indexing assumes the final state is a
            # per-layer sequence -- confirm how cell_fw/cell_bw are built.
            state_fw = encoder_state_fw[-1]
            state_bw = encoder_state_bw[-1]
            encoder_state = tf.concat(
                [tf.concat(state_fw, 1), tf.concat(state_bw, 1)], 1)
            # Width used for the reshape below: the sum of both cells'
            # output sizes.
            output_size = (self.cell_fw.output_size
                           + self.cell_bw.output_size)
        else:
            encoder_outputs, final_state = static_rnn(
                self.cell_fw,
                encoder_emb_inputs,
                sequence_length=self.sequence_length,
                dtype=tf.float32)
            # with state_is_tuple = True, if num_layers > 1,
            # here we use the state from last layer as the encoder state
            encoder_state = tf.concat(final_state[-1], 1)
            output_size = self.cell_fw.output_size

        # Insert a length-1 axis into every output so the per-step outputs
        # can be stacked along axis 1.
        top_states = [tf.reshape(output, [-1, 1, output_size])
                      for output in encoder_outputs]
        attention_states = tf.concat(top_states, 1)
        return encoder_outputs, encoder_state, attention_states
# [page residue] Comment list
# [page residue] Article table of contents