ddpg.py source code

Language: Python

Project: chainerrl · Author: chainer
def compute_actor_loss(self, batch):
        """Compute loss for actor.

        Preconditions:
          q_function must have seen up to s_{t-1} and a_{t-1}.
          policy must have seen up to s_{t-1}.
        Postconditions:
          q_function must have seen up to s_t and a_t.
          policy must have seen up to s_t.
        """

        batch_state = batch['state']
        batch_action = batch['action']
        batch_size = len(batch_action)

        # Estimated policy observes s_t
        onpolicy_actions = self.policy(batch_state).sample()

        # Q(s_t, mu(s_t)) is evaluated.
        # This should not affect the internal state of Q.
        with state_kept(self.q_function):
            q = self.q_function(batch_state, onpolicy_actions)

        # Estimated Q-function observes s_t and a_t
        if isinstance(self.q_function, Recurrent):
            self.q_function.update_state(batch_state, batch_action)

        # Avoid the numpy #9165 bug (see also: chainer #2744)
        q = q[:, :]

        # Since we want to maximize Q, loss is negation of Q
        loss = - F.sum(q) / batch_size

        # Update stats
        self.average_actor_loss *= self.average_loss_decay
        self.average_actor_loss += ((1 - self.average_loss_decay) *
                                    float(loss.data))
        return loss
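
For context, below is a minimal, self-contained sketch of the same actor-loss idea outside of chainerrl: a toy deterministic policy and a toy Q-function built from plain chainer Links, with the loss defined as the negated batch average of Q(s_t, mu(s_t)). The names toy_policy, toy_q, obs_size, and action_size are illustrative assumptions, not part of the chainerrl API.

import numpy as np
import chainer.functions as F
import chainer.links as L

obs_size, action_size, batch_size = 4, 2, 8

# Toy deterministic "policy": maps states to actions.
toy_policy = L.Linear(obs_size, action_size)
# Toy "Q-function": scores concatenated (state, action) pairs.
toy_q = L.Linear(obs_size + action_size, 1)

batch_state = np.random.randn(batch_size, obs_size).astype(np.float32)

# mu(s_t): on-policy actions for the sampled states.
onpolicy_actions = toy_policy(batch_state)
# Q(s_t, mu(s_t)): concatenate state and action, then score with Q.
q = toy_q(F.concat((batch_state, onpolicy_actions), axis=1))
# Maximizing Q is implemented as minimizing its negation, averaged over the batch.
loss = -F.sum(q) / batch_size

toy_policy.cleargrads()
toy_q.cleargrads()
loss.backward()  # gradients flow into the policy parameters through Q

In an actual training loop the policy's optimizer would then take a step on these gradients; the method above plays exactly this role inside chainerrl's DDPG agent.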