ppo.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:drl.pth 作者: seba-1511 项目源码 文件源码
def get_update(self):
        """Compute one PPO loss, back-propagate it, and return the gradients.

        Draws a batch from ``self._sample()`` and builds three losses:
        - the clipped-surrogate policy loss;
        - the critic (value) loss, weighted by ``self.critic_weight``;
        - the entropy term, weighted by ``self.entropy_weight``.
        It then back-propagates their sum and, when ``self.grad_clip > 0``,
        clips the total gradient norm. Each scalar loss is accumulated into
        ``self.stats``.

        Returns:
            list: one tensor per entry of ``self.parameters()``, in the same
            order. Each holds a clone of that parameter's ``.grad``. A
            parameter that received no gradient gets a zero tensor of its own
            shape, so the list stays positionally aligned with
            ``self.parameters()``.
        """
        # The sampled critic values are not used: the critic is re-evaluated
        # below so that the value loss has a graph to differentiate through.
        actions, log_actions, rewards, _, entropies, states, advantages = self._sample()

        # Auxiliary losses.
        values = self.critic(states)
        # Flatten both operands. If one were (N,) and the other (N, 1),
        # broadcasting would silently build an (N, N) matrix and the mean
        # would no longer be the per-sample squared error.
        critic_loss = (rewards.reshape(-1) - values.reshape(-1)).pow(2).mean()
        critic_loss = self.critic_weight * critic_loss
        # Negated, so that minimizing the total loss increases entropy.
        entropy_loss = -self.entropy_weight * entropies.mean()

        # Clipped surrogate policy loss. Advantages and the sampled
        # log-probabilities are treated as constants (detached).
        advantages = advantages.detach().view(-1, 1)
        new_actions = self.policy(states)
        log_probs = new_actions.compute_log_prob(actions)
        ratios = (log_probs - log_actions.detach()).exp()
        surr = ratios.view(-1, 1) * advantages
        clipped = th.clamp(ratios, 1.0 - self.clip, 1.0 + self.clip).view(-1, 1) * advantages
        policy_loss = -th.min(surr, clipped).mean()

        # Proceed to optimization.
        loss = policy_loss + critic_loss + entropy_loss
        # The graph is retained on every epoch except the last one. Presumably
        # this is because tensors from `_sample()` (e.g. `entropies`, which is
        # not detached) belong to a graph that is back-propagated again in
        # later epochs. TODO confirm.
        loss.backward(retain_graph=self.epoch_optimized != self.num_epochs)
        if self.grad_clip > 0.0:
            # `clip_grad_norm_` (in-place) supersedes the deprecated
            # `clip_grad_norm`.
            th.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)

        # Store statistics. `.item()` is required here: indexing a 0-dim
        # tensor with `[0]` raises on modern PyTorch.
        self.stats['Num. Updates'] += 1.0
        self.stats['Critic Loss'] += critic_loss.item()
        self.stats['Entropy Loss'] += entropy_loss.item()
        self.stats['Policy Loss'] += policy_loss.item()
        self.stats['Total Loss'] += loss.item()
        return [
            p.grad.clone() if p.grad is not None else th.zeros_like(p)
            for p in self.parameters()
        ]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号