RNNModel.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:DeeplearningForTextClassification 作者: zldeng 项目源码 文件源码
def inference(self):
        '''
        Build the forward graph and return unnormalized class scores (logits).

        Pipeline:
        1. embedding layer: look up self.input_x in self.Embedding
        2. Bi-LSTM layer (output dropout on both directions when
           self.dropout_keep_prob is not None)
        3. concat forward and backward outputs along the feature axis
        4. mean-pool the concatenated outputs over the time axis
        5. FC (fully connected) projection with self.W_projection / self.b_projection

        No softmax is applied here; the returned scores are raw logits.

        Side effects:
            Sets self.embedded_words and self.score.

        Returns:
            self.score, the output of the FC layer.
        '''
        # embedding layer (lookup pinned to CPU)
        # NOTE(review): presumably self.input_x is [batch, seq_len] token ids — confirm.
        with tf.device('/cpu:0'), tf.name_scope('embedding'):
            self.embedded_words = tf.nn.embedding_lookup(self.Embedding, self.input_x)

        def make_lstm_cell():
            """Build one LSTM direction, wrapped with output dropout when configured."""
            cell = rnn.BasicLSTMCell(self.hidden_size)
            if self.dropout_keep_prob is not None:
                cell = rnn.DropoutWrapper(cell, output_keep_prob=self.dropout_keep_prob)
            return cell

        # Bi-LSTM layer: identical, independent cells for each direction
        lstm_fw_cell = make_lstm_cell()
        lstm_bw_cell = make_lstm_cell()

        # outputs is the pair (output_fw, output_bw); with the default
        # time_major=False each is [batch, max_time, hidden_size].
        # The final states are not needed for this classifier.
        outputs, _ = tf.nn.bidirectional_dynamic_rnn(
            lstm_fw_cell, lstm_bw_cell, self.embedded_words, dtype=tf.float32)

        # concat forward and backward outputs -> [batch, max_time, 2*hidden_size]
        bilstm_output = tf.concat(outputs, axis=2)

        # mean-pool over the time axis -> [batch, 2*hidden_size]
        # NOTE(review): no sequence_length is passed to the RNN above, so any
        # padded time steps are included in this mean — confirm this is intended.
        pooled_output = tf.reduce_mean(bilstm_output, axis=1)

        # FC layer: project pooled features to class scores (logits, no softmax)
        with tf.name_scope('output'):
            self.score = tf.matmul(pooled_output, self.W_projection) + self.b_projection

        return self.score
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号