def run_lstm(self, encoded_rep, q_rep, masks):
    """Run a bidirectional LSTM over the passage, conditioned on a question vector.

    The question representation is tiled to every passage position and
    concatenated with the encoded passage. A forward LSTM reads that mixed
    input as-is; a separate backward LSTM reads it after ``_reverse`` and its
    outputs are reversed back. The two outputs are concatenated on the last
    axis.

    Args:
        encoded_rep: pair ``(encoded_question, encoded_passage)``. Only the
            passage is used; presumably shaped (batch, P, D_p) -- TODO confirm.
        q_rep: question vector; after ``expand_dims(q_rep, 1)`` it is
            ``(batch_size, 1, D)``, then tiled along axis 1 to the passage
            length.
        masks: pair ``(masks_question, masks_passage)``. Only the passage
            entry is used, as the length argument to ``_reverse``.
            NOTE(review): ``_reverse`` is defined elsewhere; if it wraps
            ``tf.reverse_sequence`` this must hold per-example lengths rather
            than a 0/1 mask -- confirm.

    Returns:
        Concatenated forward/backward LSTM outputs, ``(-1, P, 2*H)`` where
        ``H = self.hidden_size``.
    """
    _, encoded_passage = encoded_rep
    _, masks_passage = masks

    q_rep = tf.expand_dims(q_rep, 1)  # (batch_size, 1, D)
    passage_len = tf.shape(encoded_passage)[1]
    q_rep = tf.tile(q_rep, [1, passage_len, 1])  # (batch_size, P, D)
    mixed_question_passage_rep = tf.concat([encoded_passage, q_rep], axis=-1)

    with tf.variable_scope("lstm_"):
        # One cell and one scope per direction. Sharing a single cell under a
        # single scope would tie the forward and backward weights (or, on TF
        # versions whose cells create variables on every call, raise
        # "Variable ... already exists"). The forward pass keeps the scope
        # name "rnn" so its variable names are unchanged.
        cell_fw = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple=True)
        cell_bw = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple=True)

        output_attender_fw, _ = tf.nn.dynamic_rnn(
            cell_fw, mixed_question_passage_rep, dtype=tf.float32, scope="rnn")

        reverse_mixed_question_passage_rep = _reverse(
            mixed_question_passage_rep, masks_passage, 1, 0)
        output_attender_bw, _ = tf.nn.dynamic_rnn(
            cell_bw, reverse_mixed_question_passage_rep, dtype=tf.float32,
            scope="rnn_bw")
        # Undo the reversal so backward outputs line up with forward positions.
        output_attender_bw = _reverse(output_attender_bw, masks_passage, 1, 0)

        output_attender = tf.concat(
            [output_attender_fw, output_attender_bw], axis=-1)  # (-1, P, 2*H)
    return output_attender
# [scraped web-page residue] Comment list
# [scraped web-page residue] Table of contents