def node_forward(self, inputs, child_c, child_h, training):
    """One child-sum Tree-LSTM cell update with zoneout on the recurrent state.

    Args:
        inputs: input features for the current node (assumed shape
            (1, in_dim) — TODO confirm against caller).
        child_c: stacked cell states of the node's children; dim 0 indexes
            children (assumed shape (num_children, 1, mem_dim) — TODO confirm).
        child_h: stacked hidden states of the children, same layout as child_c.
        training: flag forwarded to zoneout; regularization is active only
            while training.

    Returns:
        (c, h): updated cell state and hidden state for this node, each with
        zoneout applied against a "previous" child state (a single randomly
        chosen child when self.zoneout_choose_child, else the child sum).
    """
    # Child-sum Tree-LSTM: the i/o/u gates see the sum over all children's
    # hidden states.
    child_h_sum = torch.sum(torch.squeeze(child_h, 1), 0, keepdim=True)
    # NOTE: F.sigmoid/F.tanh are deprecated; torch.sigmoid/torch.tanh are the
    # supported equivalents.
    i = torch.sigmoid(self.ix(inputs) + self.ih(child_h_sum))
    o = torch.sigmoid(self.ox(inputs) + self.oh(child_h_sum))
    u = torch.tanh(self.ux(inputs) + self.uh(child_h_sum))
    # One forget gate per child: f_k = sigmoid(W_f x + U_f h_k). The input
    # contribution is shared across children, so compute it once outside the
    # loop. (The original unsqueeze(fx, 1)/squeeze(fx, 1) pair cancels out.)
    fx = self.fx(inputs)
    f = torch.sigmoid(torch.cat([self.fh(child_hi) + fx for child_hi in child_h], 0))
    # Restore the singleton child dimension so f lines up with child_c.
    f = torch.unsqueeze(f, 1)
    # Per-child gated cell contributions: f_k * c_k.
    fc = torch.squeeze(torch.mul(f, child_c), 1)
    # Uniformly sample one child index; used as the zoneout "previous" state
    # when zoneout_choose_child is set.
    idx = Var(torch.multinomial(torch.ones(child_c.size(0)), 1), requires_grad=False)
    if self.cuda_flag:
        idx = idx.cuda()
    c = zoneout(
        current_input=torch.mul(i, u) + torch.sum(fc, 0, keepdim=True),
        previous_input=torch.squeeze(child_c.index_select(0, idx), 0)
        if self.zoneout_choose_child
        else torch.sum(torch.squeeze(child_c, 1), 0, keepdim=True),
        p=self.recurrent_dropout_c,
        training=training,
        mask=self.mask if self.commons_mask else None,
    )
    h = zoneout(
        current_input=torch.mul(o, torch.tanh(c)),
        previous_input=torch.squeeze(child_h.index_select(0, idx), 0)
        if self.zoneout_choose_child
        else child_h_sum,
        p=self.recurrent_dropout_h,
        training=training,
        mask=self.mask if self.commons_mask else None,
    )
    return c, h
# (removed scraped web-page residue that was not code: "评论列表" / "文章目录")