acer.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:chainerrl 作者: chainer 项目源码 文件源码
def update(self, t_start, t_stop, R, states, actions, rewards, values,
               action_values, action_distribs, action_distribs_mu,
               avg_action_distribs):
        """Compute the loss, backpropagate it and step the optimizer.

        The loss returned by ``self.compute_loss`` is backpropagated through
        ``self.model``, the resulting gradients are copied onto
        ``self.shared_model`` via ``copy_param.copy_grad``, the optimizer
        takes one step, and ``self.sync_parameters()`` is called.

        Args:
            t_start: Forwarded unchanged to ``compute_loss``.
            t_stop: Forwarded unchanged to ``compute_loss``.
            R: Must satisfy ``np.isscalar``; forwarded to ``compute_loss``.
            states: Forwarded unchanged to ``compute_loss``.
            actions: Forwarded unchanged to ``compute_loss``.
            rewards: Forwarded unchanged to ``compute_loss``.
            values: Forwarded unchanged to ``compute_loss``.
            action_values: Forwarded unchanged to ``compute_loss``.
            action_distribs: Forwarded unchanged to ``compute_loss``.
            action_distribs_mu: Forwarded unchanged to ``compute_loss``.
            avg_action_distribs: Forwarded unchanged to ``compute_loss``.

        NOTE(review): the exact structure of these arguments (presumably
        per-time-step containers covering ``t_start``..``t_stop``) is defined
        by ``compute_loss``, which is not visible here -- confirm there.
        """

        assert np.isscalar(R)

        total_loss = self.compute_loss(
            t_start=t_start,
            t_stop=t_stop,
            R=R,
            states=states,
            actions=actions,
            rewards=rewards,
            values=values,
            action_values=action_values,
            action_distribs=action_distribs,
            action_distribs_mu=action_distribs_mu,
            avg_action_distribs=avg_action_distribs)

        # Compute gradients using thread-specific model
        self.model.zerograds()
        total_loss.backward()
        # Copy the gradients to the globally shared model
        self.shared_model.zerograds()
        copy_param.copy_grad(
            target_link=self.shared_model, source_link=self.model)
        # Update the globally shared model
        if self.process_idx == 0:
            # Diagnostic only: sum of squared gradient entries over the
            # optimizer's target parameters. Parameters whose gradient is
            # None are skipped so that this logging path cannot raise and
            # abort the update.
            norm = sum(np.sum(np.square(param.grad))
                       for param in self.optimizer.target.params()
                       if param.grad is not None)
            self.logger.debug('grad norm:%s', norm)
        self.optimizer.update()

        self.sync_parameters()
        if isinstance(self.model, Recurrent):
            # Presumably cuts the computational graph held by the recurrent
            # state so later backward passes stop here -- confirm against
            # the Recurrent implementation.
            self.model.unchain_backward()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号