def forward(self, x):
    """Blend a transformed view of ``x`` with ``x`` itself through a sigmoid gate.

    Computes ``g * active(h(x)) + (1 - g) * x`` with ``g = sigmoid(gate(x))``,
    which is the highway-layer formulation: where the gate is close to 1 the
    transformed signal passes through, where it is close to 0 the input is
    carried over unchanged.

    Args:
        x: Input tensor. It is fed to ``self.gate`` and ``self.h`` and also
            combined elementwise with their outputs, so those outputs must be
            broadcast-compatible with ``x``.

    Returns:
        The gated elementwise mix of ``self.active(self.h(x))`` and ``x``.
    """
    # torch.sigmoid is called directly; the functional alias F.sigmoid
    # computes the same value but is flagged as deprecated by older PyTorch
    # releases.
    gate = torch.sigmoid(self.gate(x))

    # The gate is evaluated before the transform, as in the original, so any
    # stateful sub-module (e.g. one drawing random numbers) sees the same
    # call order.
    transformed = self.active(self.h(x))

    return transformed * gate + x * (1 - gate)