def gru_forward(self, embedded_words, gru_cell, reverse=False):
    """Statically unroll a GRU cell over the time axis of `embedded_words`.

    :param embedded_words: tensor of shape [batch_size, sequence_length, embed_size]
    :param gru_cell: callable taking (x_t, h_prev) and returning the next hidden state
    :param reverse: if True, scan the sequence right-to-left; the returned list
        is still ordered by the original time steps.
    :return: list of length sequence_length, each element a
        [batch_size, hidden_size] hidden-state tensor.
    """
    # Split along the time axis into sequence_length tensors of
    # [batch_size, 1, embed_size], then drop the singleton time dimension
    # so each step input is [batch_size, embed_size].
    per_step_inputs = [
        tf.squeeze(chunk, axis=1)
        for chunk in tf.split(embedded_words, self.sequence_length, axis=1)
    ]
    # NOTE(review): the hidden state is initialized to ones rather than the
    # conventional zeros — preserved as-is; confirm this is intentional.
    hidden = tf.ones((self.batch_size, self.hidden_size))
    # For a backward pass, feed the inputs in reversed order (copy, not in-place).
    scan_order = per_step_inputs[::-1] if reverse else per_step_inputs
    hidden_states = []
    for x_t in scan_order:  # x_t: [batch_size, embed_size]
        hidden = gru_cell(x_t, hidden)  # hidden: [batch_size, hidden_size]
        hidden_states.append(hidden)
    # Restore original time ordering so callers index by time step uniformly.
    if reverse:
        hidden_states.reverse()
    return hidden_states
# --- Scraped page metadata (non-code text that accompanied this snippet) ---
# a1_seq2seq_attention_model.py — file source code (python)
# views: 37 · favorites: 0 · likes: 0 · comments: 0
# (page sections: comment list, article table of contents)