decom_classification.py 文件源码

python
阅读 16 收藏 0 点赞 0 评论 0

项目:decomposable_attention 作者: shuuki4 项目源码 文件源码
def _build_rnn_encoder(self, sentence1, sentence2,
                           sentence1_lengths, sentence2_lengths):
        """Embed both sentences and encode each with a bidirectional GRU.

        Each sentence is looked up in ``self._word_embedding`` and then fed,
        batch-major (``time_major=False``), through
        ``tf.nn.bidirectional_dynamic_rnn``. The same forward and backward
        ``GRUCell`` objects are passed for both sentences.
        NOTE(review): presumably this is meant to share encoder weights
        between the two sentences -- confirm against the TF version in use.

        Args:
            sentence1: ids used to index ``self._word_embedding`` for the
                first sentence (assumed [batch, time] -- TODO confirm).
            sentence2: same as ``sentence1``, for the second sentence.
            sentence1_lengths: passed as ``sequence_length`` for sentence 1.
            sentence2_lengths: passed as ``sequence_length`` for sentence 2.

        Returns:
            A 4-tuple ``(sentence1_embedding, sentence2_embedding,
            sentence1_rnned, sentence2_rnned)`` where the ``*_rnned`` tensors
            are the forward and backward RNN outputs concatenated on axis 2.
        """
        with tf.variable_scope('word_embedding'):
            sentence1_embedding, sentence2_embedding = [
                tf.nn.embedding_lookup(self._word_embedding, token_ids)
                for token_ids in (sentence1, sentence2)
            ]

        with tf.variable_scope('rnn'):
            hidden_units = self.config['rnn']['state_size']
            cell_fw = GRUCell(hidden_units)
            cell_bw = GRUCell(hidden_units)

            def _encode(embedded, lengths):
                """Run the bi-GRU and join forward/backward results."""
                outputs, final_states = tf.nn.bidirectional_dynamic_rnn(
                    cell_fw=cell_fw,
                    cell_bw=cell_bw,
                    inputs=embedded,
                    sequence_length=lengths,
                    dtype=tf.float32,
                    time_major=False
                )
                joined_outputs = tf.concat(list(outputs), axis=2)
                # The joined final state is built but not consumed by this
                # method; only the per-step outputs are returned below.
                joined_states = tf.concat(list(final_states), axis=1)
                return joined_outputs, joined_states

            sentence1_rnned = _encode(sentence1_embedding,
                                      sentence1_lengths)[0]
            sentence2_rnned = _encode(sentence2_embedding,
                                      sentence2_lengths)[0]

        return (sentence1_embedding, sentence2_embedding,
                sentence1_rnned, sentence2_rnned)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号