def forward(self, x):
    """Blend a transformed version of ``x`` with ``x`` itself via a gate.

    Computes ``gate * candidate + (1 - gate) * x`` where
    ``candidate = self.nonlin(self.lin(x))`` and
    ``gate = self.gate_nonlin(self.gate_lin(x))``. This is the form
    used by highway-style layers.

    Args:
        x: Input passed to both ``self.lin`` and ``self.gate_lin``.
            NOTE(review): the final sum requires ``x`` to be
            broadcast-compatible with the outputs of both branches;
            confirm the layer sizes guarantee this.

    Returns:
        The gated combination of the transformed input and ``x``.
    """
    # Transform branch: projection followed by its non-linearity.
    candidate = self.nonlin(self.lin(x))

    # Gate branch: decides, per element, how much of the transform to use.
    # Presumably gate_nonlin squashes into [0, 1] (e.g. a sigmoid) -- confirm.
    gate = self.gate_nonlin(self.gate_lin(x))

    transformed_part = torch.mul(gate, candidate)
    # Whatever the gate does not pass through is carried over from x.
    carried_part = torch.mul(1 - gate, x)
    return torch.add(transformed_part, carried_part)