def accumulate_gradient(self, batch_sz, states, actions, rewards,
                        next_states, mask):
    """ Compute the temporal difference error and accumulate gradients.

        td_error = (r + gamma * max_a' Q(s', a')) - Q(s, a)
    """
    states = Variable(states)
    actions = Variable(actions)
    rewards = Variable(rewards.squeeze())
    # volatile=True: no graph is built for the target network's forward pass
    next_states = Variable(next_states, volatile=True)

    # Compute Q(s, a) for the actions actually taken; shape: (batch_sz,)
    q_values = self.policy(states)
    q_values = q_values.gather(1, actions).squeeze(1)

    # Compute max_a' Q(s', a') with the target network
    q_target_values = Variable(torch.zeros(batch_sz).type(self.dtype.FT))
    # Bootstrap only for non-terminal states; terminal targets stay zero
    q_target_values[mask] = self.target_policy(next_states).max(1)[0][mask]
    q_target_values.volatile = False  # so the Huber loss is not marked volatile

    expected_q_values = (q_target_values * self.gamma) + rewards

    # Compute Huber loss between Q(s, a) and the TD target
    loss = F.smooth_l1_loss(q_values, expected_q_values)

    # Accumulate gradients in the policy network's parameters
    loss.backward()
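
# A minimal, self-contained sketch of the same update on dummy tensors,
# assuming a modern PyTorch (>= 1.0) API where Variable/volatile are replaced
# by torch.no_grad(). The linear networks, sizes, and random batch below are
# illustrative assumptions, not part of the original agent.
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_sz, n_features, n_actions, gamma = 32, 4, 2, 0.99
policy = nn.Linear(n_features, n_actions)         # stand-in for self.policy
target_policy = nn.Linear(n_features, n_actions)  # stand-in for self.target_policy

states = torch.randn(batch_sz, n_features)
actions = torch.randint(n_actions, (batch_sz, 1))
rewards = torch.randn(batch_sz)
next_states = torch.randn(batch_sz, n_features)
mask = torch.rand(batch_sz) > 0.1                 # True where s' is non-terminal

# Q(s, a) for the actions taken; shape (batch_sz,)
q_values = policy(states).gather(1, actions).squeeze(1)

# TD target: r + gamma * max_a' Q(s', a'), with zero bootstrap at terminal states
q_target_values = torch.zeros(batch_sz)
with torch.no_grad():  # no gradients flow through the target network
    q_target_values[mask] = target_policy(next_states).max(1)[0][mask]
expected_q_values = q_target_values * gamma + rewards

# Huber loss and gradient accumulation in the policy network
loss = F.smooth_l1_loss(q_values, expected_q_values)
loss.backward()

# Assumption about the surrounding training loop: the caller zeroes gradients
# before accumulate_gradient() and applies them afterwards with an optimizer
# step on the policy network's parameters.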