Source code of ppo.py

Project: chainerrl · Author: chainer. The snippet below is the update() method of the PPO agent: it standardizes advantages over the collected memory, snapshots the current policy as the "old" policy, and runs several epochs of minibatch updates.
# update() is a method of chainerrl's PPO agent class; the module-level
# imports it relies on are shown here for context.
import copy

import chainer
from chainerrl.misc.batch_states import batch_states


def update(self):
    xp = self.xp

    # Compute the advantage mean/std over the whole memory so that every
    # minibatch is standardized with the same statistics.
    if self.standardize_advantages:
        all_advs = xp.array([b['adv'] for b in self.memory])
        mean_advs = xp.mean(all_advs)
        std_advs = xp.std(all_advs)

    # Freeze a copy of the current model; its log-probabilities act as the
    # "old" policy in the PPO probability ratio.
    target_model = copy.deepcopy(self.model)
    dataset_iter = chainer.iterators.SerialIterator(
        self.memory, self.minibatch_size)

    # Run several epochs of minibatch SGD over the collected transitions.
    dataset_iter.reset()
    while dataset_iter.epoch < self.epochs:
        batch = dataset_iter.__next__()
        states = batch_states([b['state'] for b in batch], xp, self.phi)
        actions = xp.array([b['action'] for b in batch])
        distribs, vs_pred = self.model(states)
        # The old policy's outputs are treated as constants, so no
        # backprop graph is built for them.
        with chainer.no_backprop_mode():
            target_distribs, _ = target_model(states)

        advs = xp.array([b['adv'] for b in batch], dtype=xp.float32)
        if self.standardize_advantages:
            advs = (advs - mean_advs) / std_advs

        self.optimizer.update(
            self._lossfun,
            distribs, vs_pred, distribs.log_prob(actions),
            vs_pred_old=xp.array(
                [b['v_pred'] for b in batch], dtype=xp.float32),
            target_log_probs=target_distribs.log_prob(actions),
            advs=advs,
            vs_teacher=xp.array(
                [b['v_teacher'] for b in batch], dtype=xp.float32),
        )
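
The call above delegates the actual objective to self._lossfun, which this snippet does not include. A minimal sketch of the clipped PPO surrogate loss such a function typically computes is given below; the hyperparameter names clip_eps, value_func_coef, and entropy_coef are illustrative assumptions, not verified against the repository, and the clipped value loss variant that vs_pred_old enables is omitted for brevity.

import chainer.functions as F

def _lossfun(self, distribs, vs_pred, log_probs,
             vs_pred_old, target_log_probs, advs, vs_teacher):
    # Probability ratio r = pi(a|s) / pi_old(a|s), computed in log space.
    prob_ratio = F.exp(log_probs - target_log_probs)

    # Clipped surrogate objective: take the pessimistic minimum of the
    # unclipped and clipped terms, negated for gradient descent.
    loss_policy = -F.mean(F.minimum(
        prob_ratio * advs,
        F.clip(prob_ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advs))

    # Regress the value head toward the empirical returns
    # (vs_pred and vs_teacher are assumed to share a shape here).
    loss_value = F.mean_squared_error(vs_pred, vs_teacher)

    # Entropy bonus encourages exploration.
    loss_entropy = -F.mean(distribs.entropy)

    return (loss_policy
            + self.value_func_coef * loss_value
            + self.entropy_coef * loss_entropy)

The minimum over the unclipped and clipped terms is what keeps each update close to the old policy: once the probability ratio leaves the [1 - clip_eps, 1 + clip_eps] interval in the direction that would increase the objective, the gradient through that sample vanishes.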