reinforce.py source code

python

Project: drl.pth    Author: seba-1511
def get_update(self):
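        """Turn the episodes collected since the last reset into gradients.

        Accumulates the policy, critic, and entropy losses episode by episode,
        backpropagates each one, and returns a cloned gradient for every parameter.
        """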
        num_traj = loss_stats = critics_stats = entropy_stats = policy_stats = 0.0
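        # Per-episode returns and advantages for all stored trajectories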
        all_rewards, all_advantages = self.advantage(self.rewards, self.critics, self.terminals)
        for actions_ep, rewards_ep, advantage_ep, critics_ep, entropy_ep, terminals_ep in zip(self.actions, all_rewards, all_advantages, self.critics, self.entropies, self.terminals):
            if len(actions_ep) > 0:
                # Flatten this episode's critic estimates into a 1-D tensor
                critics_ep = th.cat(critics_ep, 0).view(-1)
                # Compute losses
                critic_loss = (rewards_ep - critics_ep).pow(2).mean()
                entropy_loss = th.cat(entropy_ep).mean()
                critic_loss = self.critic_weight * critic_loss
                entropy_loss = - self.entropy_weight * entropy_loss
                # Compute policy gradients
                policy_loss = 0.0
                for action, advantage in zip(actions_ep, advantage_ep):
                    policy_loss = policy_loss - action.log_prob.mean() * advantage.data[0]
                loss = policy_loss + critic_loss + entropy_loss
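                # backward() is called once per episode, so gradients accumulate
                # across episodes before being returned below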
                loss.backward(retain_graph=True)
                if self.grad_clip > 0.0:
                    th.nn.utils.clip_grad_norm(self.parameters(), self.grad_clip)
                # Update running statistics
                loss_stats += loss.data[0]
                critics_stats += critic_loss.data[0]
                entropy_stats += entropy_loss.data[0]
                policy_stats += policy_loss.data[0]
                num_traj += 1.0

        # Store statistics
        self.stats['Num. Updates'] += 1.0
        self.stats['Num. Trajectories'] += num_traj
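        # Note: the averages below assume at least one non-empty episode (num_traj > 0)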
        self.stats['Critic Loss'] += critics_stats / num_traj
        self.stats['Entropy Loss'] += entropy_stats / num_traj
        self.stats['Policy Loss'] += policy_stats / num_traj
        self.stats['Total Loss'] += loss_stats / num_traj
        self.stats['Num. Steps'] += self.steps
        self._reset()
        return [p.grad.clone() for p in self.parameters()]
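
The method returns a list of cloned gradients rather than applying an update itself, which suggests the caller is responsible for stepping an optimizer. A minimal, hypothetical usage sketch under that assumption (only get_update() and parameters() are taken from the source above; agent and optimizer are placeholder names):

def apply_update(agent, optimizer):
    # Hypothetical helper, not part of the project source.
    # get_update() backpropagates the collected episodes and returns one
    # cloned gradient tensor per parameter, in parameters() order.
    grads = agent.get_update()
    for p, g in zip(agent.parameters(), grads):
        p.grad = g          # install the accumulated gradient on each parameter
    optimizer.step()        # e.g. a torch.optim optimizer built over agent.parameters()
    optimizer.zero_grad()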