qa_model.py source code

python

Project: Question-Answering  Author: MurtyShikhar
import tensorflow as tf  # TensorFlow 1.x API (tf.contrib, tf.variable_scope)

def run_lstm(self, encoded_rep, q_rep, masks):
    # Method excerpt from the model class; _reverse is a helper defined outside this excerpt.
    encoded_question, encoded_passage = encoded_rep
    masks_question, masks_passage = masks

    # Tile the question vector across every passage position.
    q_rep = tf.expand_dims(q_rep, 1)  # (batch_size, 1, D)
    encoded_passage_shape = tf.shape(encoded_passage)[1]
    q_rep = tf.tile(q_rep, [1, encoded_passage_shape, 1])  # (batch_size, P, D)

    mixed_question_passage_rep = tf.concat([encoded_passage, q_rep], axis=-1)

    with tf.variable_scope("lstm_"):
        cell = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple=True)
        reverse_mixed_question_passage_rep = _reverse(mixed_question_passage_rep, masks_passage, 1, 0)

        # The same cell object and scope name are passed to both passes.
        output_attender_fw, _ = tf.nn.dynamic_rnn(cell, mixed_question_passage_rep, dtype=tf.float32, scope="rnn")
        output_attender_bw, _ = tf.nn.dynamic_rnn(cell, reverse_mixed_question_passage_rep, dtype=tf.float32, scope="rnn")

        # Flip the backward outputs back into passage order.
        output_attender_bw = _reverse(output_attender_bw, masks_passage, 1, 0)

    output_attender = tf.concat([output_attender_fw, output_attender_bw], axis=-1)  # (-1, P, 2*H)
    return output_attender
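
The data flow in run_lstm can be traced on a tiny example without TensorFlow. The sketch below is plain Python for a single batch element. It assumes that `_reverse` behaves like `tf.reverse_sequence` (reversing only the first `length` time steps and leaving padding in place) and that `masks_passage` holds the true passage lengths; neither is confirmed by the excerpt. The helper names `tile_and_concat` and `reverse_by_length` are hypothetical and exist only for illustration.

```python
def tile_and_concat(passage, q_vec):
    """Append the question vector to every passage position (tile + concat)."""
    return [step + q_vec for step in passage]


def reverse_by_length(seq, length):
    """Reverse the first `length` steps; keep the padding tail where it is.

    Assumption: this mirrors what _reverse does with a length argument.
    """
    return seq[:length][::-1] + seq[length:]


# One example with 4 passage positions; the last one is padding.
passage = [[1.0, 1.5], [2.0, 2.5], [3.0, 3.5], [0.0, 0.0]]
q_vec = [9.0, 9.5]
true_length = 3

# Each step holds the passage features followed by the question features.
mixed = tile_and_concat(passage, q_vec)

# Input for the backward pass: real steps flipped, padding still last.
reversed_mixed = reverse_by_length(mixed, true_length)

# Applying the same reversal again restores passage order, which is
# why the backward outputs are reversed a second time before concat.
restored = reverse_by_length(reversed_mixed, true_length)

print(reversed_mixed[0])
print(restored == mixed)
```

Under these assumptions, reversing by length rather than flipping the whole padded sequence keeps padding out of the first steps of the backward pass.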