def forward(self, facts, G):
    '''
    Run ``self.AGRUCell`` over the sentence axis of ``facts`` and return the
    hidden state produced after the last sentence.

    facts.size() -> (#batch, #sentence, #embedding)
    fact.size()  -> (#batch, #embedding)
    G.size()     -> (#batch, #sentence)
    g.size()     -> (#batch, )
    C.size()     -> (#batch, #hidden)

    The recurrent state C starts as zeros and is updated once per sentence as
    ``C = self.AGRUCell(fact, C, g)``; ``g`` is presumably the per-sentence
    attention gate (confirm against AGRUCell). If ``facts`` has zero
    sentences, the zero state of shape (#batch, #hidden) is returned.
    '''
    # Unpacking also enforces that facts is 3-D.
    batch_num, sen_num, _ = facts.size()
    # Allocate the initial state from `facts` so it lives on the same device
    # and has the same dtype (works on CPU and any GPU), with #hidden columns
    # regardless of #embedding. `.data.new(...).zero_()` wrapped in Variable
    # behaves the same on pre-0.4 and current PyTorch.
    C = Variable(facts.data.new(batch_num, self.hidden_size).zero_())
    for sid in range(sen_num):
        fact = facts[:, sid, :]
        g = G[:, sid]
        C = self.AGRUCell(fact, C, g)
    return C
babi_main.py 文件源码
python
阅读 27
收藏 0
点赞 0
评论 0
评论列表
文章目录