a1_seq2seq_attention_model.py source code

python

Project: text_classification Author: brightmart
def gru_forward(self, embedded_words, gru_cell, reverse=False):
    """
    Unroll gru_cell over the time dimension of embedded_words.

    :param embedded_words: [None, sequence_length, self.embed_size]
    :param gru_cell: callable taking (Xt, h_t) and returning the next hidden state
    :param reverse: if True, process the sequence from the last time step to the first
    :return: hidden states: a list of length sequence_length, each element is [batch_size, hidden_size]
    """
    # Split along the time axis: a list of length sequence_length, each element is [batch_size, 1, embed_size]
    embedded_words_splitted = tf.split(embedded_words, self.sequence_length, axis=1)
    # Drop the singleton time axis: each element becomes [batch_size, embed_size]
    embedded_words_squeeze = [tf.squeeze(x, axis=1) for x in embedded_words_splitted]
    # Initial hidden state, filled with ones: [batch_size, hidden_size]
    h_t = tf.ones((self.batch_size, self.hidden_size))
    h_t_list = []
    if reverse:
        embedded_words_squeeze.reverse()
    for Xt in embedded_words_squeeze:  # Xt: [batch_size, embed_size]
        h_t = gru_cell(Xt, h_t)  # h_t: [batch_size, hidden_size] <- Xt: [batch_size, embed_size]; previous h_t: [batch_size, hidden_size]
        h_t_list.append(h_t)
    if reverse:
        # Restore the original time order so that h_t_list[t] lines up with input position t
        h_t_list.reverse()
    return h_t_list  # a list of length sequence_length, each element is [batch_size, hidden_size]
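
The reverse handling in gru_forward is easy to misread, so here is a minimal sketch of the same control flow in plain Python, without TensorFlow. The names unroll and toy_cell are hypothetical and not part of the project; toy_cell is a running sum standing in for a real GRU cell, chosen only so the states can be checked by hand.

```python
def unroll(inputs, cell, h0, reverse=False):
    """Apply `cell` one step at a time and return one state per input position."""
    steps = list(inputs)          # copy, so the caller's sequence is not mutated
    if reverse:
        steps.reverse()           # walk the sequence from the last step to the first
    h = h0
    states = []
    for x in steps:
        h = cell(x, h)            # same call order as gru_cell(Xt, h_t)
        states.append(h)
    if reverse:
        states.reverse()          # realign: states[t] belongs to input position t
    return states


def toy_cell(x, h):
    """Hypothetical stand-in for a GRU cell: a running sum of the inputs."""
    return h + x


forward = unroll([1, 2, 3], toy_cell, h0=0)
backward = unroll([1, 2, 3], toy_cell, h0=0, reverse=True)
print(forward)   # [1, 3, 6]
print(backward)  # [6, 5, 3]
```

In the backward pass, the state at position 0 has seen the whole sequence and the state at the last position has seen only the last input. That is why the second reverse() is needed before forward and backward states can be paired up position by position.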