a3_entity_network.py source code

python

Project: text_classification  Author: brightmart
# Assumed imports (TensorFlow 1.x); the excerpt itself omits them.
import tensorflow as tf
from tensorflow.contrib import rnn

def input_encoder_bi_lstm(self):
        """Encode the query and the story with bi-directional LSTMs.

        input:  query_embedding:[batch_size,sequence_length,embed_size]
                story_embedding:[batch_size,story_length,sequence_length,embed_size]
        output: query_embedding:[batch_size,hidden_size*2]
                story_embedding:[batch_size,story_length,hidden_size*2]
        """
        #1. encode query: bi-lstm layer
        lstm_fw_cell = rnn.BasicLSTMCell(self.hidden_size)  # forward direction cell
        lstm_bw_cell = rnn.BasicLSTMCell(self.hidden_size)  # backward direction cell
        if self.dropout_keep_prob is not None:
            lstm_fw_cell = rnn.DropoutWrapper(lstm_fw_cell, output_keep_prob=self.dropout_keep_prob)
            # Assignment, not comparison: with `==` the backward cell was never wrapped in dropout.
            lstm_bw_cell = rnn.DropoutWrapper(lstm_bw_cell, output_keep_prob=self.dropout_keep_prob)
        # Dynamic bidirectional RNN; returns a tuple (output_fw, output_bw), each [batch_size,sequence_length,hidden_size]
        query_hidden_output, _ = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, self.query_embedding,dtype=tf.float32,scope="query_rnn")
        query_hidden_output = tf.concat(query_hidden_output, axis=2) #[batch_size,sequence_length,hidden_size*2]
        self.query_embedding=tf.reduce_sum(query_hidden_output,axis=1) #[batch_size,hidden_size*2]
        print("input_encoder_bi_lstm.self.query_embedding:",self.query_embedding)

        #2. encode story
        # self.story_embedding:[batch_size,story_length,sequence_length,embed_size]
        self.story_embedding=tf.reshape(self.story_embedding,shape=(-1,self.story_length*self.sequence_length,self.embed_size)) #[batch_size,story_length*sequence_length,embed_size]
        lstm_fw_cell_story = rnn.BasicLSTMCell(self.hidden_size)  # forward direction cell
        lstm_bw_cell_story = rnn.BasicLSTMCell(self.hidden_size)  # backward direction cell
        if self.dropout_keep_prob is not None:
            lstm_fw_cell_story = rnn.DropoutWrapper(lstm_fw_cell_story, output_keep_prob=self.dropout_keep_prob)
            # Assignment, not comparison: with `==` the backward cell was never wrapped in dropout.
            lstm_bw_cell_story = rnn.DropoutWrapper(lstm_bw_cell_story, output_keep_prob=self.dropout_keep_prob)
        story_hidden_output, _ = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell_story, lstm_bw_cell_story, self.story_embedding,dtype=tf.float32,scope="story_rnn")
        story_hidden_output=tf.concat(story_hidden_output,axis=2) #[batch_size,story_length*sequence_length,hidden_size*2]
        story_hidden_output=tf.reshape(story_hidden_output,shape=(-1,self.story_length,self.sequence_length,self.hidden_size*2))
        self.story_embedding = tf.reduce_sum(story_hidden_output, axis=2)  # [batch_size,self.story_length,self.hidden_size*2]
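
The story branch above flattens all sentences into one long sequence, encodes it, then splits it back into sentences and sum-pools over the tokens of each sentence. A minimal NumPy sketch of that shape flow follows. It is not the author's code: `fake_bi_encoder` and `encode_story_shapes` are hypothetical names, and the stand-in encoder is a pair of random projections rather than an LSTM, used only so the shapes can be checked without TensorFlow.

```python
import numpy as np


def fake_bi_encoder(x, hidden_size, rng):
    """Hypothetical stand-in for the bi-LSTM (NOT a recurrent network).

    Maps [batch, time, embed_size] to [batch, time, hidden_size*2] by
    concatenating two independent projections, mimicking the fw/bw concat.
    """
    embed_size = x.shape[-1]
    w_fw = rng.standard_normal((embed_size, hidden_size))
    w_bw = rng.standard_normal((embed_size, hidden_size))
    fw = np.tanh(x @ w_fw)  # [batch, time, hidden_size]
    bw = np.tanh(x @ w_bw)  # [batch, time, hidden_size]
    return np.concatenate([fw, bw], axis=2)


def encode_story_shapes(story, hidden_size, seed=0):
    """Reproduce the reshape -> encode -> reshape -> sum steps of the story branch."""
    _, story_length, sequence_length, embed_size = story.shape
    rng = np.random.default_rng(seed)
    # [batch, story_length*sequence_length, embed_size]
    flat = story.reshape(-1, story_length * sequence_length, embed_size)
    # [batch, story_length*sequence_length, hidden_size*2]
    hidden = fake_bi_encoder(flat, hidden_size, rng)
    # [batch, story_length, sequence_length, hidden_size*2]
    hidden = hidden.reshape(-1, story_length, sequence_length, hidden_size * 2)
    # Sum over the tokens of each sentence: [batch, story_length, hidden_size*2]
    return hidden.sum(axis=2)


story = np.ones((4, 3, 5, 8), dtype=np.float32)  # batch=4, 3 sentences, 5 tokens, embed=8
encoded = encode_story_shapes(story, hidden_size=6)
print(encoded.shape)  # (4, 3, 12)
```

Because the sentences are flattened before encoding, the real bi-LSTM in the method carries its state across sentence boundaries within a story; the sketch only mirrors the tensor shapes, not that behaviour.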