def forward(self, inputs):
    """Advance the recurrent actor-critic network by one step.

    Args:
        inputs: A pair ``(observation, (hx, cx))``. ``observation`` is fed to
            ``self.linear1``; ``(hx, cx)`` is the previous recurrent state
            handed to ``self.lstm``.

    Returns:
        A 3-tuple ``(critic_out, actor_out, (hx, cx))``:
        ``critic_out`` is ``self.critic_linear`` applied to the new hidden
        state, ``actor_out`` is ``self.actor_linear`` applied to the same
        hidden state, and ``(hx, cx)`` is the pair returned by ``self.lstm``,
        which the caller feeds back in on the next step.
    """
    observation, (prev_hidden, prev_cell) = inputs

    features = F.elu(self.linear1(observation))

    # NOTE(review): assumes self.lstm is LSTMCell-like, i.e. it returns exactly
    # (h, c) for a single step — confirm against the module's constructor.
    new_hidden, new_cell = self.lstm(features, (prev_hidden, prev_cell))

    # Both heads read the freshly computed hidden state; the critic is
    # evaluated first, matching the original call order.
    critic_out = self.critic_linear(new_hidden)
    actor_out = self.actor_linear(new_hidden)
    return critic_out, actor_out, (new_hidden, new_cell)