def _build_rnn_encoder(self, sentence1, sentence2_pos, sentence2_neg,
                       sentence1_lengths, sentence2_pos_lengths, sentence2_neg_lengths):
    """Embed three id batches and encode them with a bidirectional GRU.

    Args:
        sentence1: ids looked up in ``self._word_embedding`` for the first
            sentence batch.
        sentence2_pos: ids for the ``pos`` second-sentence batch.
        sentence2_neg: ids for the ``neg`` second-sentence batch.
        sentence1_lengths: passed as ``sequence_length`` when encoding
            ``sentence1``.
        sentence2_pos_lengths: ``sequence_length`` values for ``sentence2_pos``.
        sentence2_neg_lengths: ``sequence_length`` values for ``sentence2_neg``.

    Returns:
        A 6-tuple ``(sentence1_embedding, sentence2_pos_embedding,
        sentence2_neg_embedding, sentence1_rnned, sentence2_pos_rnned,
        sentence2_neg_rnned)``. Each ``*_rnned`` tensor is the forward and
        backward RNN outputs concatenated along axis 2. The bidirectional
        final states are computed but not returned.
    """

    def _encode(cell_fw, cell_bw, inputs, seq_lengths):
        # Inputs are batch-major (time_major=False). The final states are
        # concatenated as well, although every call site below discards them.
        outputs, final_states = tf.nn.bidirectional_dynamic_rnn(
            cell_fw, cell_bw, inputs,
            sequence_length=seq_lengths,
            time_major=False,
            dtype=tf.float32)
        merged_outputs = tf.concat(list(outputs), 2)
        merged_states = tf.concat(list(final_states), 1)
        return merged_outputs, merged_states

    with tf.variable_scope('word_embedding'):
        sentence1_embedding, sentence2_pos_embedding, sentence2_neg_embedding = [
            tf.nn.embedding_lookup(self._word_embedding, token_ids)
            for token_ids in (sentence1, sentence2_pos, sentence2_neg)
        ]

    with tf.variable_scope('rnn'):
        units = self.config['rnn']['state_size']
        # The same two cell objects encode both sentence1 and sentence2.
        # Presumably they share weights across both calls; confirm for the TF
        # version in use.
        cell_fw, cell_bw = GRUCell(units), GRUCell(units)

        sentence1_rnned = _encode(cell_fw, cell_bw,
                                  sentence1_embedding, sentence1_lengths)[0]

        # Encode the pos and neg batches in a single RNN call by stacking them
        # on axis 0, then cut the output back into two equal halves.
        # NOTE(review): assumes pos and neg have the same batch size and the
        # same padded length; otherwise the concat fails or the halves do not
        # line up with pos/neg.
        stacked_inputs = tf.concat(
            [sentence2_pos_embedding, sentence2_neg_embedding], 0)
        stacked_lengths = tf.concat(
            [sentence2_pos_lengths, sentence2_neg_lengths], 0)
        stacked_rnned = _encode(cell_fw, cell_bw,
                                stacked_inputs, stacked_lengths)[0]
        sentence2_pos_rnned, sentence2_neg_rnned = tf.split(
            stacked_rnned, num_or_size_splits=2, axis=0)

    return (sentence1_embedding, sentence2_pos_embedding, sentence2_neg_embedding,
            sentence1_rnned, sentence2_pos_rnned, sentence2_neg_rnned)
# [scraped web-page residue, not code] Comment list
# [scraped web-page residue, not code] Table of contents