babi_main.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:Dynamic-memory-networks-plus-Pytorch 作者: dandelin 项目源码 文件源码
def forward(self, facts, G):
        '''
        facts.size() -> (#batch, #sentence, #hidden = #embedding)
        fact.size() -> (#batch, #hidden = #embedding)
        G.size() -> (#batch, #sentence)
        g.size() -> (#batch, )
        C.size() -> (#batch, #hidden)
        '''
        batch_num, sen_num, embedding_size = facts.size()
        C = Variable(torch.zeros(self.hidden_size)).cuda()
        for sid in range(sen_num):
            fact = facts[:, sid, :]
            g = G[:, sid]
            if sid == 0:
                C = C.unsqueeze(0).expand_as(fact)
            C = self.AGRUCell(fact, C, g)
        return C
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号