import torch
from torch.autograd import Variable  # pre-0.4 PyTorch Variable/volatile API


def accumulate_gradient(self, batch_sz, states, actions, rewards,
                        next_states, mask):
    """Compute the cross-entropy between the projected target distribution
    phi(TZ(s_, a)) and the predicted distribution Z(s, a), and accumulate
    the resulting gradients.
    """
    states = Variable(states)
    actions = Variable(actions)
    # volatile=True disables graph construction for the target-side pass
    next_states = Variable(next_states, volatile=True)

    # Compute the probabilities of Z(s, a) for the actions actually taken
    q_probs = self.policy(states)
    actions = actions.view(batch_sz, 1, 1)
    action_mask = actions.expand(batch_sz, 1, self.atoms_no)
    qa_probs = q_probs.gather(1, action_mask).squeeze()

    # Compute the projected target distribution phi(TZ(s_, a))
    target_qa_probs = self._get_categorical(next_states, rewards, mask)

    # Cross-entropy of phi(TZ(x_, a)) || Z(x, a)
    qa_probs.data.clamp_(0.01, 0.99)  # Tudor's trick for avoiding nans
    loss = -torch.sum(target_qa_probs * torch.log(qa_probs))

    # Accumulate gradients
    loss.backward()
Source file: categorical_update.py
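The heavy lifting above happens in `self._get_categorical`, which builds the projected target distribution phi(TZ(s_, a)). As a rough illustration of that projection step, here is a minimal sketch written against a recent PyTorch release, assuming atoms linearly spaced on [v_min, v_max] and assuming the best next-action probabilities have already been computed from the target network. The names (`project_target_distribution`, `next_probs`, `support`, `v_min`, `v_max`, `gamma`) are illustrative and not taken from the original file.

import torch

def project_target_distribution(next_probs, rewards, mask,
                                support, v_min, v_max, gamma=0.99):
    """Sketch of the C51 projection: map the Bellman-updated atoms
    r + gamma * z back onto the fixed support and return target probs."""
    batch_sz, atoms_no = next_probs.size()
    delta_z = (v_max - v_min) / (atoms_no - 1)

    # Bellman update of every atom; mask is 1 for non-terminal transitions.
    tz = rewards.unsqueeze(1) + mask.unsqueeze(1) * gamma * support.unsqueeze(0)
    tz = tz.clamp(v_min, v_max)

    # Fractional position of each updated atom on the fixed grid.
    b = (tz - v_min) / delta_z
    lower, upper = b.floor().long(), b.ceil().long()
    # If b lands exactly on an atom, lower == upper and both weights would be
    # zero; nudge one index so the mass is not lost.
    lower[(upper > 0) & (lower == upper)] -= 1
    upper[(lower < atoms_no - 1) & (lower == upper)] += 1

    # Split each atom's probability mass between its two neighbouring atoms.
    target = torch.zeros(batch_sz, atoms_no)
    offset = (torch.arange(batch_sz) * atoms_no).long().unsqueeze(1)
    target.view(-1).index_add_(0, (lower + offset).view(-1),
                               (next_probs * (upper.float() - b)).view(-1))
    target.view(-1).index_add_(0, (upper + offset).view(-1),
                               (next_probs * (b - lower.float())).view(-1))
    return target

Each Bellman-updated atom r + gamma * z_i generally falls between two atoms of the fixed grid, so its probability mass is split between the two neighbours in proportion to proximity; that split is exactly what the two index_add_ calls implement.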