import tensorflow as tf
from tensorflow.contrib import rnn


def input_encoder_bi_lstm(self):
"""use bi-directional lstm to encode query_embedding:[batch_size,sequence_length,embed_size]
and story_embedding:[batch_size,story_length,sequence_length,embed_size]
output:query_embedding:[batch_size,hidden_size*2] story_embedding:[batch_size,self.story_length,self.hidden_size*2]
"""
    # 1. encode query: bi-lstm layer
    lstm_fw_cell = rnn.BasicLSTMCell(self.hidden_size)  # forward direction cell
    lstm_bw_cell = rnn.BasicLSTMCell(self.hidden_size)  # backward direction cell
    if self.dropout_keep_prob is not None:
        lstm_fw_cell = rnn.DropoutWrapper(lstm_fw_cell, output_keep_prob=self.dropout_keep_prob)
        lstm_bw_cell = rnn.DropoutWrapper(lstm_bw_cell, output_keep_prob=self.dropout_keep_prob)
    # bidirectional_dynamic_rnn returns a (fw_outputs, bw_outputs) tuple, each [batch_size, sequence_length, hidden_size]
    query_hidden_output, _ = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, self.query_embedding, dtype=tf.float32, scope="query_rnn")
    query_hidden_output = tf.concat(query_hidden_output, axis=2)  # [batch_size, sequence_length, hidden_size*2]
    self.query_embedding = tf.reduce_sum(query_hidden_output, axis=1)  # [batch_size, hidden_size*2]
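    # Note: sum-pooling over the time axis yields a fixed-size query vector; mean-pooling
    # or using the final states returned by bidirectional_dynamic_rnn would be common
    # alternatives, but this model sums.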
print("input_encoder_bi_lstm.self.query_embedding:",self.query_embedding)
#2. encode story
# self.story_embedding:[batch_size,story_length,sequence_length,embed_size]
self.story_embedding=tf.reshape(self.story_embedding,shape=(-1,self.story_length*self.sequence_length,self.embed_size)) #[self.story_length*self.sequence_length,self.embed_size]
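    # The reshape folds every sentence of a story onto one time axis so a single
    # bi-LSTM pass covers all sentences; note the recurrence therefore runs across
    # sentence boundaries rather than encoding each sentence independently.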
    lstm_fw_cell_story = rnn.BasicLSTMCell(self.hidden_size)  # forward direction cell
    lstm_bw_cell_story = rnn.BasicLSTMCell(self.hidden_size)  # backward direction cell
    if self.dropout_keep_prob is not None:
        lstm_fw_cell_story = rnn.DropoutWrapper(lstm_fw_cell_story, output_keep_prob=self.dropout_keep_prob)
        lstm_bw_cell_story = rnn.DropoutWrapper(lstm_bw_cell_story, output_keep_prob=self.dropout_keep_prob)
    story_hidden_output, _ = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell_story, lstm_bw_cell_story, self.story_embedding, dtype=tf.float32, scope="story_rnn")
    story_hidden_output = tf.concat(story_hidden_output, axis=2)  # [batch_size, story_length*sequence_length, hidden_size*2]
    story_hidden_output = tf.reshape(story_hidden_output, shape=(-1, self.story_length, self.sequence_length, self.hidden_size * 2))
    self.story_embedding = tf.reduce_sum(story_hidden_output, axis=2)  # sum over tokens of each sentence -> [batch_size, story_length, hidden_size*2]
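
# A minimal usage sketch (not from the original file): it wires up the attributes the
# encoder reads and runs one forward pass. The toy `_Model` container and every value
# below are assumptions for illustration; in the real repo they come from the
# EntityNetwork class configuration.
if __name__ == "__main__":
    class _Model(object):
        pass

    m = _Model()
    m.hidden_size = 8          # assumed hyperparameter
    m.story_length = 3         # assumed: sentences per story
    m.sequence_length = 5      # assumed: tokens per sentence
    m.embed_size = 16          # assumed embedding width
    m.dropout_keep_prob = 1.0  # keep everything for this smoke test
    m.query_embedding = tf.random_normal([2, m.sequence_length, m.embed_size])
    m.story_embedding = tf.random_normal([2, m.story_length, m.sequence_length, m.embed_size])

    input_encoder_bi_lstm(m)   # called as a plain function on the toy container
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        q, s = sess.run([m.query_embedding, m.story_embedding])
        print(q.shape)  # (2, 16)    -> [batch_size, hidden_size*2]
        print(s.shape)  # (2, 3, 16) -> [batch_size, story_length, hidden_size*2]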