def node_forward(self, inputs, child_c, child_h):
    """Run one Child-Sum Tree-LSTM step for a single tree node.

    The input (``i``), output (``o``) and update (``u``) activations are
    computed from the node input and the *sum* of the children's hidden
    states; a separate forget activation ``f_k`` is computed for every
    child ``k``.  The new cell is ``c = i * u + sum_k f_k * c_k`` and the new
    hidden state is ``h = o * tanh(c)``.

    Args:
        inputs: input tensor for this node, fed to ``self.ix``, ``self.ox``,
            ``self.ux`` and ``self.fx``.  Presumably ``(1, in_dim)`` -- TODO
            confirm against caller.
        child_c: children's memory cells, one child per index along dim 0.
            Must hold the same number of elements as the stacked forget
            gates; presumably ``(num_children, 1, mem_dim)`` like
            ``child_h`` -- TODO confirm.
        child_h: children's hidden states, one child per index along dim 0;
            a singleton dim 1 (if present) is squeezed away before summing.

    Returns:
        Tuple ``(c, h)``: the node's new memory cell and hidden state.

    Note:
        At least one child is required -- ``torch.cat`` raises on an empty
        list.
    """
    child_h_sum = torch.sum(torch.squeeze(child_h, 1), 0)

    i = torch.sigmoid(self.ix(inputs) + self.ih(child_h_sum))
    o = torch.sigmoid(self.ox(inputs) + self.oh(child_h_sum))
    u = torch.tanh(self.ux(inputs) + self.uh(child_h_sum))

    # One forget gate per child: f_k = sigmoid(fx(inputs) + fh(h_k)).
    fx = self.fx(inputs)
    f = torch.sigmoid(
        torch.cat([self.fh(child_hi) + fx for child_hi in child_h], 0)
    )
    # Lay the gates out exactly like child_c so the product below is a
    # per-child elementwise product.  Adding singleton dims here instead
    # would broadcast (n, 1, 1, d) against (n, 1, d) into an (n, n, 1, d)
    # cross-product and corrupt c for nodes with more than one child.
    f = f.view_as(child_c)
    fc = torch.squeeze(torch.mul(f, child_c), 1)

    c = torch.mul(i, u) + torch.sum(fc, 0)
    h = torch.mul(o, torch.tanh(c))
    return c, h
评论列表
文章目录