categorical_update.py source code

python
Project: categorical-dqn    Author: floringogianu
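import torch
from torch.autograd import Variable

def accumulate_gradient(self, batch_sz, states, actions, rewards,
                        next_states, mask):
        """ Accumulate gradients of the cross-entropy between the predicted
            return distribution Z(s,a) and the projected target
            distribution TZ(s_,a).
        """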
        states = Variable(states)
        actions = Variable(actions)
        # volatile=True (PyTorch < 0.4) disables autograd for the target pass
        next_states = Variable(next_states, volatile=True)

        # Compute the atom probabilities of Q(s,a) for the actions in the batch
        q_probs = self.policy(states)
        actions = actions.view(batch_sz, 1, 1)
        action_mask = actions.expand(batch_sz, 1, self.atoms_no)
        # squeeze only the action dim so a batch of size 1 keeps its batch dim
        qa_probs = q_probs.gather(1, action_mask).squeeze(1)

        # Compute distribution of Q(s_,a)
        target_qa_probs = self._get_categorical(next_states, rewards, mask)

        # Compute the cross-entropy of phi(TZ(x_,a)) || Z(x,a)
        # Tudor's trick for avoiding nans: clamp in place so log() never sees 0
        qa_probs.data.clamp_(0.01, 0.99)
        loss = - torch.sum(target_qa_probs * torch.log(qa_probs))

        # Accumulate gradients
        loss.backward()
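
The two core steps of the method above, selecting the distribution of the chosen action and taking the clamped cross-entropy against the target, can be traced by hand on a toy batch. What follows is a minimal pure-Python sketch, not the project's code: the helper names `select_action_distribution` and `categorical_cross_entropy` and all the numbers are hypothetical, and it only reproduces the loss value, not the gradient computation that `loss.backward()` performs.

```python
import math


def select_action_distribution(q_probs, actions):
    """Pick, per sample, the atom probabilities of the action taken.

    q_probs: nested list shaped [batch][actions][atoms]
    actions: one action index per sample
    Plays the role of the gather(1, action_mask) call in the snippet.
    """
    return [sample[a] for sample, a in zip(q_probs, actions)]


def categorical_cross_entropy(target_probs, qa_probs, lo=0.01, hi=0.99):
    """Summed cross-entropy -sum(m * log(p)) over batch and atoms.

    Predicted probabilities are clamped to [lo, hi] before the log,
    mirroring the clamp_(0.01, 0.99) call in the snippet.
    """
    loss = 0.0
    for m_row, p_row in zip(target_probs, qa_probs):
        for m, p in zip(m_row, p_row):
            p = min(max(p, lo), hi)  # keep log() away from 0
            loss -= m * math.log(p)
    return loss


# Hypothetical toy batch: 2 samples, 2 actions, 3 atoms.
q_probs = [
    [[0.2, 0.5, 0.3], [0.0, 1.0, 0.0]],
    [[0.6, 0.3, 0.1], [0.1, 0.1, 0.8]],
]
actions = [1, 0]
# Hypothetical projected target distributions, one row per sample.
target = [[0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]

qa_probs = select_action_distribution(q_probs, actions)
loss = categorical_cross_entropy(target, qa_probs)
print(qa_probs, loss)
```

Without the clamp, the first sample would evaluate `math.log(0.0)` for its zero-probability atoms, which raises an error in plain Python and yields `-inf` (and then NaN once multiplied by a zero target) in tensor code.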