model.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:treehopper 作者: tomekkorbak 项目源码 文件源码
def node_forward(self, inputs, child_c, child_h, training):
        """Compute the memory cell and hidden state of a single tree node.

        The gate equations follow the Child-Sum Tree-LSTM update:

        - The input gate ``i``, output gate ``o`` and candidate ``u`` are
          driven by ``inputs`` and the sum of the children's hidden states.
        - A separate forget gate is computed for every child from ``inputs``
          and that child's own hidden state.

        Both resulting states are then passed through ``zoneout`` together
        with a "previous" state taken from the children. That state is either
        one child chosen uniformly at random (``self.zoneout_choose_child``)
        or the sum over all children.

        Args:
            inputs: features for this node; fed to the ``*x`` projections.
            child_c: children's memory cells, indexed by child along dim 0.
                Dim 1 is squeezed away, so it is expected to be size 1
                (presumably shaped (num_children, 1, mem_dim) -- TODO confirm).
            child_h: children's hidden states, same layout as ``child_c``.
            training: forwarded unchanged to ``zoneout``.

        Returns:
            Tuple ``(c, h)``: the node's memory cell and hidden state.
        """
        # Sum the children's hidden states: drop the singleton dim 1, then
        # reduce over the child axis while keeping it as a size-1 row.
        child_h_sum = torch.sum(torch.squeeze(child_h, 1), 0, keepdim=True)

        i = torch.sigmoid(self.ix(inputs) + self.ih(child_h_sum))
        o = torch.sigmoid(self.ox(inputs) + self.oh(child_h_sum))
        u = torch.tanh(self.ux(inputs) + self.uh(child_h_sum))

        # One forget gate per child. The input projection does not depend on
        # the child, so it is computed once rather than inside the loop.
        fx = self.fx(inputs)
        f = torch.cat([self.fh(child_hi) + fx for child_hi in child_h], 0)
        f = torch.sigmoid(f)
        # Add a singleton dim 1 so f lines up with child_c for the
        # element-wise product, then drop it again from the result.
        f = torch.unsqueeze(f, 1)
        fc = torch.squeeze(torch.mul(f, child_c), 1)

        if self.zoneout_choose_child:
            # Pick one child uniformly at random. Its states serve as the
            # "previous" values for zoneout. The index (and its device copy)
            # is only produced when it is actually used.
            num_children = child_c.size(0)
            idx = Var(torch.multinomial(torch.ones(num_children), 1),
                      requires_grad=False)
            if self.cuda_flag:
                idx = idx.cuda()
            prev_c = torch.squeeze(child_c.index_select(0, idx), 0)
            prev_h = torch.squeeze(child_h.index_select(0, idx), 0)
        else:
            prev_c = torch.sum(torch.squeeze(child_c, 1), 0, keepdim=True)
            prev_h = child_h_sum

        c = zoneout(
            current_input=torch.mul(i, u) + torch.sum(fc, 0, keepdim=True),
            previous_input=prev_c,
            p=self.recurrent_dropout_c,
            training=training,
            mask=self.mask if self.commons_mask else None
        )
        h = zoneout(
            current_input=torch.mul(o, torch.tanh(c)),
            previous_input=prev_h,
            p=self.recurrent_dropout_h,
            training=training,
            mask=self.mask if self.commons_mask else None
        )

        return c, h
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号