def _get_categorical(self, next_states, rewards, mask):
    """Build the projected categorical (C51-style) target distribution.

    Evaluates ``self.target_policy`` on ``next_states`` and picks, per batch
    row, the action with the largest expected value ``sum_z z * p(x, a, z)``.
    It then applies ``Tz = r + mask * gamma * z`` to every atom of
    ``self.support`` and clamps the result to ``[v_min, v_max]``. Finally it
    splits each atom's probability linearly between the two neighbouring
    support points.

    Args:
        next_states: batch of next states. Only ``size(0)`` is read here;
            the whole tensor is passed to ``self.target_policy``.
        rewards: rewards, expanded against the ``(batch, atoms_no)``
            probabilities. Presumably shaped ``(batch, 1)``; TODO confirm
            with the caller.
        mask: multiplied with ``self.gamma`` (after ``.float()``).
            Presumably 0 for terminal transitions and 1 otherwise, shaped
            like ``rewards``.

    Returns:
        ``Variable(m)`` where ``m`` is ``self.m`` zeroed and refilled in
        place. The result therefore shares storage with ``self.m`` and is
        overwritten by the next call. ``self.m`` is assumed to be a
        contiguous buffer with ``batch_sz * atoms_no`` elements (presumably
        ``(batch, atoms_no)``) and the same dtype as the probabilities.
    """
    batch_sz = next_states.size(0)
    # Compute probabilities p(x, a) with the target network; .data detaches.
    probs = self.target_policy(next_states).data
    # Greedy action a* = argmax_a sum_z z * p(x, a, z).
    qs = torch.mul(probs, self.support.expand_as(probs))
    argmax_a = qs.sum(2).max(1)[1].unsqueeze(1).unsqueeze(1)
    action_mask = argmax_a.expand(batch_sz, 1, self.atoms_no)
    # squeeze(1) drops only the action axis. A bare squeeze() would also
    # drop the batch axis when batch_sz == 1 and break the expands below.
    qa_probs = probs.gather(1, action_mask).squeeze(1)
    # Mask gamma and reshape it together with rewards to fit p(x, a*).
    rewards = rewards.expand_as(qa_probs)
    gamma = (mask.float() * self.gamma).expand_as(qa_probs)
    # Apply the Bellman operator to every support atom and clip to the range.
    bellman_op = rewards + gamma * self.support.unsqueeze(0).expand_as(rewards)
    bellman_op = torch.clamp(bellman_op, self.v_min, self.v_max)
    # Fractional position of each target atom on the support grid. The extra
    # clamp guards against floating-point overshoot yielding an index of -1
    # or atoms_no after floor/ceil.
    b = (bellman_op - self.v_min) / self.delta_z
    b = torch.clamp(b, 0, self.atoms_no - 1)
    l = b.floor().long()
    u = b.ceil().long()
    # If b sits exactly on a support point then l == u, and both weights
    # (u - b) and (b - l) below are zero, so the mass would be lost. This
    # always happens for values clamped to v_min / v_max. Move one
    # neighbour away so the full mass lands on index b.
    l[(u > 0) & (l == u)] -= 1
    u[(l < (self.atoms_no - 1)) & (l == u)] += 1
    # Distribute probability. Row i owns the flat slice
    # [i * atoms_no, (i + 1) * atoms_no) of m.view(-1). This is the
    # vectorised form of, for every (i, j):
    #     m[i, l[i, j]] += qa_probs[i, j] * (u[i, j] - b[i, j])
    #     m[i, u[i, j]] += qa_probs[i, j] * (b[i, j] - l[i, j])
    # Optimized by https://github.com/tudor-berariu
    m = self.m.fill_(0)
    offset = (
        (torch.arange(0, batch_sz) * self.atoms_no)
        .type(self.dtype.LT)
        .unsqueeze(1)
        .expand(batch_sz, self.atoms_no)
    )
    m.view(-1).index_add_(0, (l + offset).view(-1),
                          (qa_probs * (u.type_as(b) - b)).view(-1))
    m.view(-1).index_add_(0, (u + offset).view(-1),
                          (qa_probs * (b - l.type_as(b))).view(-1))
    return Variable(m)
# categorical_update.py source file (python)
# Scraped page metadata (translated): views 36, favorites 0, likes 0, comments 0;
# comment list; table of contents