import torch as th

def get_update(self):
    """Accumulate the A2C losses over the stored trajectories, backpropagate,
    and return a copy of the resulting gradients."""
    num_traj = loss_stats = critics_stats = entropy_stats = policy_stats = 0.0
    # Compute discounted returns and advantages for all stored episodes up front
    all_rewards, all_advantages = self.advantage(self.rewards, self.critics, self.terminals)
    for actions_ep, rewards_ep, advantage_ep, critics_ep, entropy_ep, terminals_ep in zip(
            self.actions, all_rewards, all_advantages, self.critics, self.entropies, self.terminals):
        if len(actions_ep) > 0:
            # Flatten the per-step critic values for this episode
            critics_ep = th.cat(critics_ep, 0).view(-1)
            # Compute value and entropy losses
            critic_loss = (rewards_ep - critics_ep).pow(2).mean()
            entropy_loss = th.cat(entropy_ep).mean()
            critic_loss = self.critic_weight * critic_loss
            entropy_loss = -self.entropy_weight * entropy_loss
            # Compute the policy-gradient loss; the advantage is treated as a
            # constant, so no gradient flows back through the critic
            policy_loss = 0.0
            for action, advantage in zip(actions_ep, advantage_ep):
                policy_loss = policy_loss - action.log_prob.mean() * advantage.item()
            loss = policy_loss + critic_loss + entropy_loss
            loss.backward(retain_graph=True)
            if self.grad_clip > 0.0:
                th.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)
            # Update running statistics
            loss_stats += loss.item()
            critics_stats += critic_loss.item()
            entropy_stats += entropy_loss.item()
            policy_stats += policy_loss.item()
            num_traj += 1.0
    # Store statistics
    self.stats['Num. Updates'] += 1.0
    self.stats['Num. Trajectories'] += num_traj
    num_traj = max(num_traj, 1.0)  # guard against division by zero if no trajectory completed
    self.stats['Critic Loss'] += critics_stats / num_traj
    self.stats['Entropy Loss'] += entropy_stats / num_traj
    self.stats['Policy Loss'] += policy_stats / num_traj
    self.stats['Total Loss'] += loss_stats / num_traj
    self.stats['Num. Steps'] += self.steps
    self._reset()
    return [p.grad.clone() for p in self.parameters()]
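
The list of cloned gradients returned above is presumably handed to whatever process coordinates the workers and owns the shared copy of the model. Below is a minimal sketch of such a consumer, assuming a central learner; the `shared_model`, `optimizer`, and `grads` names are illustrative and not part of the original code.

import torch as th

# Minimal sketch (assumption, not from the original code): apply a list of
# per-parameter gradients, as returned by get_update(), to a shared model.
shared_model = th.nn.Linear(4, 2)                       # stand-in for the shared policy/value network
optimizer = th.optim.Adam(shared_model.parameters(), lr=1e-3)

# Pretend these came from a worker's get_update(); zeros with matching shapes
# are fabricated here only to keep the example self-contained.
grads = [th.zeros_like(p) for p in shared_model.parameters()]

optimizer.zero_grad()
for shared_p, grad in zip(shared_model.parameters(), grads):
    shared_p.grad = grad       # install the worker's gradient on the shared parameter
optimizer.step()               # one update step on the shared parameters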