def _add_encoders(self):
    """Build the query encoder and the bidirectional document encoder.

    Reads ``self.embeddings``, ``self.queries_placeholder``,
    ``self.query_lengths_placeholder``, ``self.documents_placeholder``,
    ``self.document_lengths_placeholder``, ``self.encoder_cell_state_size``,
    ``self.dropout_enabled`` and ``self.mode``.

    Sets:
        self.query_last: final state of the query GRU for every example.
        self.encoder_outputs: forward and backward document encoder outputs
            concatenated along axis 2.
        self.final_encoder_state: forward and backward final states of the
            document encoder concatenated along axis 1.
    """
    # Output dropout is applied whenever dropout is enabled, except in
    # 'decode' mode.
    output_keep_prob = 0.8
    use_dropout = self.dropout_enabled and self.mode != 'decode'

    with tf.variable_scope('query_encoder'):
        query_encoder_cell = GRUCell(self.encoder_cell_state_size)
        if use_dropout:
            query_encoder_cell = DropoutWrapper(
                cell=query_encoder_cell, output_keep_prob=output_keep_prob)
        query_embeddings = tf.nn.embedding_lookup(
            self.embeddings, self.queries_placeholder)
        _, query_final_state = rnn.dynamic_rnn(
            query_encoder_cell, query_embeddings,
            sequence_length=self.query_lengths_placeholder,
            swap_memory=True, dtype=tf.float32)
        # Use the returned final state rather than outputs[:, -1, :]:
        # with sequence_length set, outputs past an example's length are
        # zeroed, so the last time step is all zeros for every query shorter
        # than the longest one in the batch. The state is copied through
        # instead. NOTE(review): the state is not subject to the wrapper's
        # output dropout - confirm this is acceptable for training.
        self.query_last = query_final_state

    with tf.variable_scope('encoder'):
        fw_cell = GRUCell(self.encoder_cell_state_size)
        bw_cell = GRUCell(self.encoder_cell_state_size)
        if use_dropout:
            fw_cell = DropoutWrapper(
                cell=fw_cell, output_keep_prob=output_keep_prob)
            bw_cell = DropoutWrapper(
                cell=bw_cell, output_keep_prob=output_keep_prob)
        embeddings = tf.nn.embedding_lookup(
            self.embeddings, self.documents_placeholder)
        (outputs_fw, outputs_bw), (state_fw, state_bw) = (
            rnn.bidirectional_dynamic_rnn(
                fw_cell, bw_cell,
                embeddings,
                sequence_length=self.document_lengths_placeholder,
                swap_memory=True,
                dtype=tf.float32))
        self.encoder_outputs = tf.concat([outputs_fw, outputs_bw], 2)
        # Same padding issue as above; in addition the backward RNN finishes
        # at time step 0, so slicing the last time step never yielded its
        # summary of the whole document. Concatenating the two final states
        # keeps the same [batch, 2 * state_size] layout for GRU cells.
        self.final_encoder_state = tf.concat([state_fw, state_bw], 1)
# [web page residue] Comment list
# [web page residue] Article table of contents