model.py source code

python

Project: a3c-mujoco  Author: Feryal
```python
import torch


def forward(self, non_rgb_state, rgb_state, h):
    # Encode the RGB observation with two convolutions and ReLU activations
    x = self.relu(self.conv1(rgb_state))
    x = self.relu(self.conv2(x))
    x = x.view(x.size(0), -1)  # flatten conv features to (batch, features)
    # Fuse the image features with the non-RGB part of the state
    x = self.fc1(torch.cat((x, non_rgb_state), 1))
    h = self.lstm(x, h)  # h is (hidden state, cell state)
    x = h[0]
    # Six softmax actor heads, each yielding a categorical distribution.
    # The clamp is meant to prevent 1s (and hence NaNs), but 1 - 1e-20 rounds
    # to exactly 1.0 in floating point, so it never changes a softmax output.
    policies = tuple(
        self.softmax(head(x)).clamp(max=1 - 1e-20)
        for head in (self.fc_actor1, self.fc_actor2, self.fc_actor3,
                     self.fc_actor4, self.fc_actor5, self.fc_actor6)
    )
    V = self.fc_critic(x)  # state-value estimate from the critic head
    return policies, V, h
```
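The `forward` method above depends on layers defined elsewhere in its class. As a minimal sketch, assuming an `nn.LSTMCell` recurrent core and purely illustrative sizes (64x64 RGB input, a 9-dimensional non-RGB state, 128 hidden units, 3 actions per head), the hypothetical `ActorCritic` module below wires up matching layers so the pass runs end to end. It keeps the six actor heads in an `nn.ModuleList` instead of six named attributes, and drops the clamp because its bound equals 1.0. The hypothetical helper `sample_actions` shows one common way to use several softmax heads together: sample each head independently and sum the per-head log-probabilities and entropies. This is not necessarily how the a3c-mujoco project itself does it.

```python
import torch
import torch.nn as nn


class ActorCritic(nn.Module):
    """Hypothetical module with the layers the forward() above expects.

    All sizes are illustrative assumptions, not the project's settings.
    """

    def __init__(self, non_rgb_dim=9, image_size=64, hidden_size=128,
                 actions_per_head=3, num_heads=6):
        super().__init__()
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=4, stride=2)
        # Infer the flattened conv feature size with a dummy forward pass
        with torch.no_grad():
            dummy = torch.zeros(1, 3, image_size, image_size)
            conv_out = self.conv2(self.conv1(dummy)).numel()
        self.fc1 = nn.Linear(conv_out + non_rgb_dim, hidden_size)
        self.lstm = nn.LSTMCell(hidden_size, hidden_size)
        # One linear layer per actor head (replaces fc_actor1 ... fc_actor6)
        self.actor_heads = nn.ModuleList(
            nn.Linear(hidden_size, actions_per_head) for _ in range(num_heads))
        self.fc_critic = nn.Linear(hidden_size, 1)

    def forward(self, non_rgb_state, rgb_state, h):
        x = self.relu(self.conv1(rgb_state))
        x = self.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = self.fc1(torch.cat((x, non_rgb_state), 1))
        h = self.lstm(x, h)  # returns (hidden state, cell state)
        x = h[0]
        policies = tuple(self.softmax(head(x)) for head in self.actor_heads)
        return policies, self.fc_critic(x), h


def sample_actions(policies):
    """Sample one action per head, treating the heads as independent."""
    dists = [torch.distributions.Categorical(probs=p) for p in policies]
    actions = [d.sample() for d in dists]
    # Joint log-probability and entropy are sums over independent heads
    log_prob = sum(d.log_prob(a) for d, a in zip(dists, actions))
    entropy = sum(d.entropy() for d in dists)
    return torch.stack(actions, dim=1), log_prob, entropy


if __name__ == "__main__":
    torch.manual_seed(0)
    batch, hidden = 2, 128
    model = ActorCritic(hidden_size=hidden)
    h = (torch.zeros(batch, hidden), torch.zeros(batch, hidden))
    policies, value, h = model(torch.rand(batch, 9),
                               torch.rand(batch, 3, 64, 64), h)
    actions, log_prob, entropy = sample_actions(policies)
    print(actions.shape, value.shape, log_prob.shape)
    print(1 - 1e-20 == 1.0)  # True: the original clamp bound is exactly 1.0
```

Summing log-probabilities across heads is valid only under the independence assumption stated above. A policy-gradient loss would then weight `log_prob` by the advantage computed from `value`.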