def questionLSTM(self, q, q_real_len, reuse = False, scope= "questionLSTM"):
    """Encode zero-padded questions with an LSTM and return the hidden state
    at each question's true last token.

    Args
        q: zero padded questions (word-id tensor), shape=[batch_size, q_max_len]
        q_real_len: original question length, shape = [batch_size]
            # NOTE(review): the arithmetic below broadcasts q_real_len against a
            # [batch_size] range, so a [batch_size, 1] tensor (as the old doc
            # claimed) would produce a [batch_size, batch_size] index matrix —
            # confirm callers pass a flat [batch_size] tensor.
        reuse: whether to reuse the LSTM cell variables (for weight sharing).
        scope: variable scope name for the RNN.
    Returns
        embedded_q: embedded questions, shape = [batch_size, q_hidden(32)]
    """
    # [batch_size, q_max_len] -> [batch_size, q_max_len, embed_dim]
    embedded_q_word = tf.nn.embedding_lookup(self.q_word_embed_matrix, q)
    # static_rnn wants a Python list of q_max_len tensors of [batch_size, embed_dim]
    q_input = tf.unstack(embedded_q_word, num = self.q_max_len, axis=1)
    lstm_cell = rnn.BasicLSTMCell(self.q_hidden, reuse = reuse)
    outputs, _ = rnn.static_rnn(lstm_cell, q_input, dtype = tf.float32, scope = scope)
    # list of [batch_size, q_hidden] -> [q_max_len, batch_size, q_hidden]
    outputs = tf.stack(outputs)
    # -> [batch_size, q_max_len, q_hidden] so rows of one example are contiguous
    outputs = tf.transpose(outputs, [1,0,2])
    # Flat index of each example's last real timestep in the flattened outputs.
    index = tf.range(0, self.batch_size) * (self.q_max_len) + (q_real_len - 1)
    # BUG FIX: reshape width must match the cell size q_hidden (was s_hidden,
    # which belongs to a different LSTM and breaks the gather unless the two
    # hidden sizes happen to be equal).
    outputs = tf.gather(tf.reshape(outputs, [-1, self.q_hidden]), index)
    return outputs
# (stray web-page residue from the source this code was copied from:
#  "评论列表" = comment list, "文章目录" = article table of contents —
#  commented out so the module parses)