def forward(self, contexts, word_embedding):
    '''Encode every sentence of each story into a single "fact" vector.

    Args:
        contexts: tensor of size (#batch, #sentence, #token), presumably
            token ids -- TODO confirm.
        word_embedding: callable applied to the flattened ids of size
            (#batch, #sentence x #token); its output is reshaped to
            (#batch, #sentence, #token, #embedding), so it is presumably an
            embedding module producing (#batch, #sentence x #token, #embedding).

    Returns:
        facts of size (#batch, #sentence, #hidden = #embedding): the sum of
        the two halves of the GRU output (forward + backward directions).

    Intermediate shapes:
        position_encoding() -> (#batch, #sentence, #embedding)
    '''
    batch_num, sen_num, token_num = contexts.size()

    # Embed all tokens with one call by flattening the sentence/token axes,
    # then restore them so position_encoding can reduce each sentence.
    contexts = contexts.view(batch_num, -1)
    contexts = word_embedding(contexts)
    contexts = contexts.view(batch_num, sen_num, token_num, -1)
    contexts = position_encoding(contexts)
    contexts = self.dropout(contexts)

    # No explicit h0: nn.GRU initialises the hidden state to zeros on the
    # input's device when none is passed. That is exactly what the former
    # torch.zeros(2, batch, hidden).cuda() supplied, minus the hard-coded
    # CUDA placement, so this now also runs on CPU.
    facts, _ = self.gru(contexts)

    # Assumes self.gru is a single-layer bidirectional GRU with
    # batch_first=True (the former h0 had leading dim 2) -- TODO confirm.
    # The last axis then holds [forward | backward] outputs of width
    # hidden_size each; sum them. Must use self.hidden_size: a bare
    # `hidden_size` is not defined in this method.
    hidden = self.hidden_size
    facts = facts[:, :, :hidden] + facts[:, :, hidden:]
    return facts
babi_main.py 文件源码
python
阅读 27
收藏 0
点赞 0
评论 0
评论列表
文章目录