def encode_sentences(self, text_emb, text_len, text_len_mask):
    """Encode every sentence with a bidirectional LSTM and flatten the result.

    A forward and a backward ``util.CustomLSTMCell`` are run over the same
    embeddings with ``tf.nn.dynamic_rnn``. The backward direction is obtained
    by reversing each sentence (up to its own length) before the RNN and
    reversing the outputs back afterwards, so both directions are aligned
    per token before they are concatenated.

    Args:
        text_emb: Embedded tokens, indexed as
            [num_sentences, max_sentence_length, emb].
        text_len: Per-sentence lengths; used both as the RNN
            ``sequence_length`` and as ``seq_lengths`` for the reversals.
            Presumably shape [num_sentences] -- confirm against the caller.
        text_len_mask: Passed through unchanged to
            ``self.flatten_emb_by_sentence``. Presumably a
            [num_sentences, max_sentence_length] mask of real (non-padding)
            tokens -- TODO confirm.

    Returns:
        Whatever ``self.flatten_emb_by_sentence`` returns for the
        concatenated forward/backward outputs and ``text_len_mask``.
    """
    num_sentences = tf.shape(text_emb)[0]

    # Transpose before and after for efficiency: the RNNs below run with
    # time_major=True, so the time axis has to come first.
    inputs = tf.transpose(text_emb, [1, 0, 2])  # [max_sentence_length, num_sentences, emb]

    # Each cell is built and applied to the inputs inside its own variable
    # scope so anything it creates is kept separate per direction.
    with tf.variable_scope("fw_cell"):
        cell_fw = util.CustomLSTMCell(self.config["lstm_size"], num_sentences, self.dropout)
        preprocessed_inputs_fw = cell_fw.preprocess_input(inputs)
    with tf.variable_scope("bw_cell"):
        cell_bw = util.CustomLSTMCell(self.config["lstm_size"], num_sentences, self.dropout)
        preprocessed_inputs_bw = cell_bw.preprocess_input(inputs)
        # Reverse only the first text_len[i] steps of sentence i, so padding
        # stays at the end and the backward RNN starts on the last real token.
        preprocessed_inputs_bw = tf.reverse_sequence(preprocessed_inputs_bw,
                                                     seq_lengths=text_len,
                                                     seq_dim=0,
                                                     batch_dim=1)

    def tiled_initial_state(cell):
        # Repeat the cell's initial (c, h) along axis 0, once per sentence.
        # Presumably each is a single [1, size] row -- confirm in util.
        return tf.contrib.rnn.LSTMStateTuple(
            tf.tile(cell.initial_state.c, [num_sentences, 1]),
            tf.tile(cell.initial_state.h, [num_sentences, 1]))

    state_fw = tiled_initial_state(cell_fw)
    state_bw = tiled_initial_state(cell_bw)

    with tf.variable_scope("lstm"):
        with tf.variable_scope("fw_lstm"):
            # Only the per-step outputs are needed; the final state is dropped.
            fw_outputs, _ = tf.nn.dynamic_rnn(cell=cell_fw,
                                              inputs=preprocessed_inputs_fw,
                                              sequence_length=text_len,
                                              initial_state=state_fw,
                                              time_major=True)
        with tf.variable_scope("bw_lstm"):
            bw_outputs, _ = tf.nn.dynamic_rnn(cell=cell_bw,
                                              inputs=preprocessed_inputs_bw,
                                              sequence_length=text_len,
                                              initial_state=state_bw,
                                              time_major=True)

    # Undo the earlier reversal so backward outputs line up with forward ones.
    bw_outputs = tf.reverse_sequence(bw_outputs,
                                     seq_lengths=text_len,
                                     seq_dim=0,
                                     batch_dim=1)

    # Concatenate both directions on the feature axis, then go back to
    # batch-major: [num_sentences, max_sentence_length, fw + bw features].
    text_outputs = tf.concat([fw_outputs, bw_outputs], 2)
    text_outputs = tf.transpose(text_outputs, [1, 0, 2])
    return self.flatten_emb_by_sentence(text_outputs, text_len_mask)
评论列表
文章目录