pcl.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:chainerrl 作者: chainer 项目源码 文件源码
def update_on_policy(self, statevar):
        """Run the online (on-policy) update for the rollout just collected.

        The rollout covers time steps ``self.t_start`` up to, but not
        including, ``self.t``. Unless ``self.disable_online_update`` is set,
        this method does the following:

        * builds the "next value" for every step in that range;
        * scores each recorded action with the distribution stored for its
          step;
        * appends the result of ``self.compute_loss(...)`` to
          ``self.online_batch_losses``.

        Once ``self.batchsize`` losses are queued, their mean is passed to
        ``self.update`` and the queue is emptied. On every path that does not
        raise, the per-rollout history is then reset through
        ``self.init_history_data_for_online_update()``.

        Args:
            statevar: Model input for the state that follows the last step of
                the rollout, or ``None``.

                When ``None``, the last step's next value is a
                ``chainer.Variable`` of zeros built with
                ``self.xp.zeros_like`` from
                ``self.past_values[self.t - 1].data``. This presumably means
                the episode terminated; confirm with the caller.

                Otherwise ``self.model`` is called on it inside
                ``state_kept(self.model)``, and the second element of the
                model's output is used as that value.

        Raises:
            AssertionError: If the rollout is empty (``self.t_start >= self.t``).
        """
        assert self.t_start < self.t

        if self.disable_online_update:
            # No online learning: only discard the collected history.
            self.init_history_data_for_online_update()
            return

        last_step = self.t - 1

        # Every step except the last takes the value recorded at the
        # following step as its "next value".
        next_values = {
            step - 1: self.past_values[step]
            for step in range(self.t_start + 1, self.t)
        }

        # The last step needs a value from outside the stored rollout.
        if statevar is None:
            bootstrap_value = chainer.Variable(
                self.xp.zeros_like(self.past_values[last_step].data))
        else:
            # NOTE(review): state_kept presumably saves and restores the
            # model's internal (e.g. recurrent) state so this extra forward
            # pass does not disturb it. Confirm against its definition.
            with state_kept(self.model):
                _, bootstrap_value = self.model(statevar)
        next_values[last_step] = bootstrap_value

        log_probs = {}
        for step, action in self.past_actions.items():
            # Add a leading axis of size 1 before scoring the action.
            batched_action = self.xp.asarray(self.xp.expand_dims(action, 0))
            log_probs[step] = self.past_action_distrib[step].log_prob(
                batched_action)

        rollout_loss = self.compute_loss(
            t_start=self.t_start, t_stop=self.t,
            rewards=self.past_rewards,
            values=self.past_values,
            next_values=next_values,
            log_probs=log_probs)
        self.online_batch_losses.append(rollout_loss)

        # Apply one averaged update for every `batchsize` queued rollout losses.
        if len(self.online_batch_losses) == self.batchsize:
            mean_loss = chainerrl.functions.sum_arrays(
                self.online_batch_losses) / self.batchsize
            self.update(mean_loss)
            self.online_batch_losses = []

        self.init_history_data_for_online_update()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号