def forward(self, inputs):
    """Run one step of the convolutional encoder, recurrent cell and heads.

    Args:
        inputs: A pair ``(observation, (hx, cx))``. ``observation`` is fed
            to ``self.conv1``; ``(hx, cx)`` is the recurrent state passed
            to ``self.lstm`` together with the flattened conv features.

    Returns:
        A 3-tuple ``(critic_out, actor_out, (hx, cx))`` where
        ``critic_out`` is ``self.critic_linear(hx)``, ``actor_out`` is
        ``self.actor_linear(hx)`` and ``(hx, cx)`` is the state returned by
        ``self.lstm``, to be passed back in on the next call.
    """
    x, (hx, cx) = inputs

    # Four conv layers, each followed by an ELU non-linearity.
    for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
        x = F.elu(conv(x))

    # Flatten to rows of 32 * 3 * 3 features for the recurrent cell.
    # NOTE(review): assumes conv4 emits 32 channels of 3x3 per sample --
    # confirm against the layer definitions in __init__.
    x = x.view(-1, 32 * 3 * 3)

    # self.lstm is unpacked as a (hidden, cell) pair -- presumably an
    # nn.LSTMCell; verify against __init__.
    hx, cx = self.lstm(x, (hx, cx))

    # Both heads read the new hidden state.
    return self.critic_linear(hx), self.actor_linear(hx), (hx, cx)
评论列表
文章目录