pal.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:chainerrl 作者: chainer 项目源码 文件源码
def _compute_y_and_t(self, exp_batch, gamma):
    """Compute predicted Q-values and persistent advantage learning targets.

    The target is the persistent advantage learning (PAL) operator::

        T_PAL Q(s, a) = T Q(s, a) + alpha * max(A(s, a), A(s', a))

    where ``T Q(s, a) = r + gamma * (1 - terminal) * max_b Q_target(s', b)``
    is the Bellman target and ``A`` is whatever the target network's
    ``compute_advantage`` returns (presumably ``Q(s, a) - max_b Q(s, b)``,
    i.e. non-positive -- TODO confirm against the action-value class).

    Args:
        exp_batch (dict): Batch of transitions. Keys read here:
            ``'state'``, ``'action'``, ``'reward'``, ``'next_state'`` and
            ``'is_state_terminal'``.
        gamma (float): Discount factor applied to the bootstrapped
            next-state value.

    Returns:
        tuple: ``(batch_q, tpal_q)`` where ``batch_q`` holds the online
        network's Q-values for the taken actions (computed with backprop
        enabled) and ``tpal_q`` holds the PAL targets (computed under
        ``chainer.no_backprop_mode``).
    """
    batch_state = exp_batch['state']
    batch_size = len(exp_batch['reward'])

    qout = self.q_function(batch_state)

    batch_actions = exp_batch['action']
    batch_q = qout.evaluate_actions(batch_actions)

    # Compute target values; no graph is built, so targets act as constants
    # in the loss.
    with chainer.no_backprop_mode():
        # Target-network output on the *current* states; used only for the
        # current-state advantage A(s, a).
        target_qout = self.target_q_function(batch_state)

        batch_next_state = exp_batch['next_state']

        # NOTE(review): state_kept presumably restores any internal
        # (e.g. recurrent) state of the target network after this extra
        # forward pass -- confirm against its definition.
        with state_kept(self.target_q_function):
            target_next_qout = self.target_q_function(batch_next_state)
        next_q_max = F.reshape(target_next_qout.max, (batch_size,))

        batch_rewards = exp_batch['reward']
        batch_terminal = exp_batch['is_state_terminal']

        # T Q: Bellman operator. Discount with the caller-supplied ``gamma``
        # so the target honours the value passed to this method; the
        # (1 - terminal) mask drops the bootstrap term at episode ends.
        t_q = batch_rewards + gamma * (1.0 - batch_terminal) * next_q_max

        # T_PAL Q: persistent advantage learning operator. The next-state
        # advantage is deliberately evaluated at the same action ``a`` that
        # was taken in ``s``.
        cur_advantage = F.reshape(
            target_qout.compute_advantage(batch_actions), (batch_size,))
        next_advantage = F.reshape(
            target_next_qout.compute_advantage(batch_actions),
            (batch_size,))
        tpal_q = t_q + self.alpha * F.maximum(cur_advantage, next_advantage)

    return batch_q, tpal_q
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号