ppo.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:chainerrl 作者: chainer 项目源码 文件源码
def _lossfun(self,
             distribs, vs_pred, log_probs,
             vs_pred_old, target_log_probs,
             advs, vs_teacher):
        """Compute the total PPO loss and refresh the running loss averages.

        The returned value is the sum of three terms: a clipped policy
        surrogate loss, a value-function loss scaled by
        ``self.value_func_coef`` and a negative-entropy term scaled by
        ``self.entropy_coef``.

        Args:
            distribs: Object whose ``entropy`` attribute is averaged for the
                entropy term (presumably the current policy's action
                distributions -- confirm against the caller).
            vs_pred: Value predictions that are compared with ``vs_teacher``.
            log_probs: Log-probabilities forming the numerator of the
                probability ratio.
            vs_pred_old: Reference value predictions; only used to bound
                ``vs_pred`` when ``self.clip_eps_vf`` is not ``None``.
            target_log_probs: Log-probabilities forming the denominator of
                the probability ratio.
            advs: Values multiplied with the probability ratio (advantage
                estimates, judging by the name).
            vs_teacher: Targets the value predictions are regressed towards.

        Returns:
            ``loss_policy + value_func_coef * loss_value_func
            + entropy_coef * loss_entropy``.

        Side effects:
            ``self.average_loss_policy``, ``self.average_loss_value_func``
            and ``self.average_loss_entropy`` are each moved towards the
            newly computed loss by a factor ``1 - self.average_loss_decay``.
        """
        entropy = distribs.entropy

        # exp(log p - log p_target) is the probability ratio; a trailing
        # axis is appended before it is multiplied with ``advs``.
        ratio = F.expand_dims(F.exp(log_probs - target_log_probs), axis=-1)

        # Clipped surrogate: take the element-wise smaller of the raw and the
        # clipped objective, then negate the mean to obtain a loss.
        unclipped_objective = ratio * advs
        clipped_objective = (
            F.clip(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advs)
        loss_policy = - F.mean(
            F.minimum(unclipped_objective, clipped_objective))

        if self.clip_eps_vf is not None:
            # Keep the larger of the plain squared error and the squared
            # error of the prediction bounded to within ``clip_eps_vf`` of
            # the reference prediction.
            plain_sq_err = F.square(vs_pred - vs_teacher)
            vs_pred_bounded = _elementwise_clip(
                vs_pred,
                vs_pred_old - self.clip_eps_vf,
                vs_pred_old + self.clip_eps_vf)
            bounded_sq_err = F.square(vs_pred_bounded - vs_teacher)
            loss_value_func = F.mean(F.maximum(plain_sq_err, bounded_sq_err))
        else:
            loss_value_func = F.mean_squared_error(vs_pred, vs_teacher)

        loss_entropy = -F.mean(entropy)

        # Exponential moving averages of the three loss terms, kept on the
        # host side via ``cuda.to_cpu``.
        step = 1 - self.average_loss_decay
        self.average_loss_policy += step * (
            cuda.to_cpu(loss_policy.data) - self.average_loss_policy)
        self.average_loss_value_func += step * (
            cuda.to_cpu(loss_value_func.data) - self.average_loss_value_func)
        self.average_loss_entropy += step * (
            cuda.to_cpu(loss_entropy.data) - self.average_loss_entropy)

        weighted_value_loss = self.value_func_coef * loss_value_func
        weighted_entropy_loss = self.entropy_coef * loss_entropy
        return loss_policy + weighted_value_loss + weighted_entropy_loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号