# Assumes `import torch as th` at module level.
def get_update(self):
    actions, log_actions, rewards, critics, entropies, states, advantages = self._sample()
    # Compute auxiliary losses
    # Re-evaluate the critic so its loss is differentiable w.r.t. the current critic parameters
    critics = self.critic(states)
    critic_loss = (rewards - critics).pow(2).mean()
    critic_loss = self.critic_weight * critic_loss
    entropy_loss = entropies.mean()
    entropy_loss = - self.entropy_weight * entropy_loss
    # Compute the clipped (PPO) policy loss
    advantages = advantages.detach().view(-1, 1)
    new_actions = self.policy(states)
    log_probs = new_actions.compute_log_prob(actions)
    ratios = (log_probs - log_actions.detach()).exp()
    surr = ratios.view(-1, 1) * advantages
    clipped = th.clamp(ratios, 1.0 - self.clip, 1.0 + self.clip).view(-1, 1) * advantages
    policy_loss = - th.min(surr, clipped).mean()
    # Proceed to optimization
    loss = policy_loss + critic_loss + entropy_loss
    if self.epoch_optimized == self.num_epochs:
        # Last optimization epoch on this batch: the graph can be freed
        loss.backward(retain_graph=False)
    else:
        # Keep the graph so the sampled tensors can be reused on the next epoch
        loss.backward(retain_graph=True)
    if self.grad_clip > 0.0:
        th.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)
    # Store statistics
    self.stats['Num. Updates'] += 1.0
    self.stats['Critic Loss'] += critic_loss.item()
    self.stats['Entropy Loss'] += entropy_loss.item()
    self.stats['Policy Loss'] += policy_loss.item()
    self.stats['Total Loss'] += loss.item()
    # Return a copy of the gradients for the caller to apply
    return [p.grad.clone() for p in self.parameters()]
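
# Usage sketch: get_update() runs backward() and returns cloned gradients rather
# than stepping an optimizer, so an outer loop decides how to apply them (e.g.
# after averaging gradients from several workers).  The helper below is only a
# minimal sketch of such a caller, assuming a `learner` instance of the class
# above and any torch optimizer built over learner.parameters(); both names are
# hypothetical and not part of the original code.
def apply_update(learner, optimizer, grads):
    # Install the externally computed gradients and take one optimizer step.
    optimizer.zero_grad()
    for param, grad in zip(learner.parameters(), grads):
        param.grad = grad
    optimizer.step()

# Example:
# optimizer = th.optim.Adam(learner.parameters(), lr=3e-4)
# for _ in range(learner.num_epochs):
#     apply_update(learner, optimizer, learner.get_update())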