def answer_module(self):
    """Answer module: decode logits from the final episodic-memory vector.

    Reads from ``self``:
        m_T: final memory state; used as the decoder's initial hidden
            state (noted upstream as [batch_size, hidden_size]).
        query_embedding: question representation, concatenated onto the
            decoder input at every step (noted upstream as
            [batch_size, embedding_size]).
        decode_with_sequences: when truthy, run ``sequence_length``
            decoding steps; otherwise run exactly one step.
        batch_size, hidden_size, num_classes, sequence_length: sizes used
            to build the zero input, the GRU cell and the projection.

    Returns:
        When ``decode_with_sequences`` is truthy, the per-step logits
        stacked along axis 1: [batch_size, sequence_length, num_classes].
        Otherwise the logits of the single step: [batch_size, num_classes].
    """
    steps = self.sequence_length if self.decode_with_sequences else 1
    state = self.m_T  # initial decoder hidden state

    # TODO: this is normally the embedding of a special '<GO>' token; pass
    # that embedding in from outside to replace the zero vector.
    y_pred = tf.zeros((self.batch_size, self.hidden_size))

    # y_pred is never updated between steps, so the decoder input is the
    # same tensor at every step; build it once instead of once per step.
    # NOTE(review): a real sequence decoder would feed the previous
    # prediction back in here -- confirm whether that is intended.
    decoder_input = tf.concat([y_pred, self.query_embedding], axis=1)

    # Build the recurrent cell and the output projection ONCE so that all
    # decoding steps share the same weights. Constructing them inside the
    # loop gave every step its own independent parameters (TF names them
    # gru_cell, gru_cell_1, ... / dense, dense_1, ...), which is not a
    # recurrent decoder. Single-step decoding builds the same one cell and
    # one projection as before; multi-step checkpoints trained with the
    # per-step layers will not restore into this graph.
    cell = rnn.GRUCell(self.hidden_size)
    projection = tf.layers.Dense(units=self.num_classes)

    logits_list = []
    for _ in range(steps):
        _output, state = cell(decoder_input, state)
        logits_list.append(projection(state))  # [batch_size, num_classes]

    if self.decode_with_sequences:
        # Sequence output: [batch_size, sequence_length, num_classes].
        return tf.stack(logits_list, axis=1)
    # Single answer: [batch_size, num_classes].
    return logits_list[0]
a8_dynamic_memory_network.py 文件源码
python
阅读 34
收藏 0
点赞 0
评论 0
评论列表
文章目录