def _build_rnn_encoder(self, sentence1, sentence2,
                       sentence1_lengths, sentence2_lengths):
    """Embed two sentences and encode each with a bidirectional GRU.

    Both sentences are run through the same pair of ``GRUCell`` objects.
    Presumably this is meant to share the encoder weights between the two
    sentences. It relies on TF reusing a cell's variables on repeated calls,
    so confirm this for the TF version in use.

    Args:
        sentence1: Token ids looked up in ``self._word_embedding``.
            Assumed batch-major, i.e. (batch, time), because the RNN runs
            with ``time_major=False``. TODO confirm.
        sentence2: Token ids for the second sentence, same convention.
        sentence1_lengths: Passed as ``sequence_length`` when encoding
            ``sentence1``.
        sentence2_lengths: Passed as ``sequence_length`` when encoding
            ``sentence2``.

    Returns:
        Tuple ``(sentence1_embedding, sentence2_embedding, sentence1_rnned,
        sentence2_rnned)``. Each ``*_rnned`` tensor holds the forward and
        backward RNN outputs concatenated along axis 2.
    """
    with tf.variable_scope('word_embedding'):
        embedded = [tf.nn.embedding_lookup(self._word_embedding, token_ids)
                    for token_ids in (sentence1, sentence2)]

    with tf.variable_scope('rnn'):
        num_units = self.config['rnn']['state_size']
        fw_cell = GRUCell(num_units)
        bw_cell = GRUCell(num_units)

        def encode(inputs, lengths):
            """Run the bi-GRU; return (outputs, final_state), each fw/bw-joined."""
            outputs, final_states = tf.nn.bidirectional_dynamic_rnn(
                fw_cell, bw_cell,
                inputs,
                sequence_length=lengths,
                time_major=False,
                dtype=tf.float32,
            )
            fw_out, bw_out = outputs
            fw_last, bw_last = final_states
            # Outputs are joined on axis 2 (features, batch-major layout).
            # Final states are joined on axis 1.
            return (tf.concat([fw_out, bw_out], 2),
                    tf.concat([fw_last, bw_last], 1))

        lengths_per_sentence = (sentence1_lengths, sentence2_lengths)
        # Only the per-timestep outputs are kept; the final states are unused.
        rnned = [encode(inputs, lengths)[0]
                 for inputs, lengths in zip(embedded, lengths_per_sentence)]

    return embedded[0], embedded[1], rnned[0], rnned[1]
decom_classification.py 文件源码
python
阅读 16
收藏 0
点赞 0
评论 0
评论列表
文章目录