dqn_update.py source code

python

Project: categorical-dqn    Author: floringogianu
def accumulate_gradient(self, batch_sz, states, actions, rewards,
                            next_states, mask):
        """ Compute the temporal difference error.
            td_error = (r + gamma * max Q(s_,a)) - Q(s,a)
        """
        states = Variable(states)
        actions = Variable(actions)
        rewards = Variable(rewards.squeeze())
        next_states = Variable(next_states, volatile=True)

        # Compute Q(s, a)
        q_values = self.policy(states)
        q_values = q_values.gather(1, actions)

        # Compute max_a Q(s_, a) using the target network
        q_target_values = Variable(torch.zeros(batch_sz).type(self.dtype.FT))

        # Bootstrap for non-terminal states
        q_target_values[mask] = self.target_policy(next_states).max(
                1, keepdim=True)[0][mask]
        q_target_values.volatile = False      # so volatile does not propagate into the Huber loss
        expected_q_values = (q_target_values * self.gamma) + rewards

        # Compute Huber loss
        loss = F.smooth_l1_loss(q_values, expected_q_values)

        # Accumulate gradients
        loss.backward()
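
The snippet above uses the pre-0.4 PyTorch API (Variable and the volatile flag), which has since been removed. Below is a minimal, hypothetical sketch of the same TD-target and Huber-loss computation in current PyTorch, assuming policy, target_policy, gamma, and a boolean mask of non-terminal transitions equivalent to the ones used above; it is an illustration, not the project's code.

python

import torch
import torch.nn.functional as F

def accumulate_gradient(policy, target_policy, gamma, batch_sz,
                        states, actions, rewards, next_states, mask):
    """Sketch of the same update in post-0.4 PyTorch (no Variable/volatile)."""
    rewards = rewards.squeeze()

    # Q(s, a) for the actions actually taken; actions is assumed to have shape (batch, 1)
    q_values = policy(states).gather(1, actions).squeeze(1)

    # Bootstrap target: r + gamma * max_a Q_target(s', a), only for non-terminal states
    q_target_values = torch.zeros(batch_sz, device=states.device)
    with torch.no_grad():                     # replaces volatile=True on next_states
        q_target_values[mask] = target_policy(next_states).max(1)[0][mask]
    expected_q_values = rewards + gamma * q_target_values

    # Huber loss between Q(s, a) and the TD target, then accumulate gradients
    loss = F.smooth_l1_loss(q_values, expected_q_values)
    loss.backward()
    return loss.item()

Because the method only calls loss.backward(), gradients from several batches can be accumulated before the caller performs a single optimizer.step() and zero_grad(), which is what the name accumulate_gradient suggests.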