import torch
import torch.nn.functional as F
from torch.autograd import Variable


def _train(args, T, model, shared_model, shared_average_model, optimiser, policies, Qs, Vs, actions, rewards, Qret, average_policies, target_class, pred_class, old_policies=None):
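  """Accumulate the ACER losses over one trajectory and apply an update.

  Steps backwards through the stored transitions, maintaining the Retrace
  target Qret; accumulates the truncated importance-sampled policy gradient
  (with bias correction when off-policy), an entropy bonus, the Q-value
  regression loss and an auxiliary classification loss; then hands the
  summed loss to _update_networks.
  """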
  off_policy = old_policies is not None
  policy_loss, value_loss, class_loss = 0, 0, 0

  # Calculate n-step returns in forward view, stepping backwards from the last state
  t = len(rewards)
  for i in reversed(range(t)):
    # Importance sampling weights ρ ← π(∙|s_i) / µ(∙|s_i); 1 for on-policy
    rho = policies[i].detach() / old_policies[i] if off_policy else Variable(torch.ones(1, ACTION_SIZE))

    # Qret ← r_i + γQret
    Qret = rewards[i] + args.discount * Qret
    # Advantage A ← Qret - V(s_i; θ)
    A = Qret - Vs[i]

    # Log policy log(π(a_i|s_i; θ))
    log_prob = policies[i].gather(1, actions[i]).log()
    # g ← min(c, ρ_a_i)∙∇θ∙log(π(a_i|s_i; θ))∙A
    single_step_policy_loss = -(rho.gather(1, actions[i]).clamp(max=args.trace_max) * log_prob * A).mean(0)  # Average over batch
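    # Note: clamping the importance weight at c = args.trace_max truncates large
    # off-policy corrections, bounding the variance of this gradient estimate.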
    # Off-policy bias correction
    if off_policy:
      # g ← g + Σ_a [1 - c/ρ_a]_+∙π(a|s_i; θ)∙∇θ∙log(π(a|s_i; θ))∙(Q(s_i, a; θ) - V(s_i; θ))
      bias_weight = (1 - args.trace_max / rho).clamp(min=0) * policies[i]
      single_step_policy_loss -= (bias_weight * policies[i].log() * (Qs[i].detach() - Vs[i].expand_as(Qs[i]).detach())).sum(1).mean(0)
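      # Note: [1 - c/ρ_a]_+ is non-zero only for actions whose importance weight
      # exceeds c, so this term compensates for the bias introduced by the
      # truncation above, using Q - V in place of the Retrace advantage.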
    if args.trust_region:
      # Policy update dθ ← dθ + ∂θ/∂θ∙z* (see the trust-region projection sketch after this function)
      policy_loss += _trust_region_loss(model, policies[i], average_policies[i], single_step_policy_loss, args.trust_region_threshold)
    else:
      # Policy update dθ ← dθ + ∂θ/∂θ∙g
      policy_loss += single_step_policy_loss

    # Entropy regularisation dθ ← dθ + β∙∇θH(π(s_i; θ))
    policy_loss -= args.entropy_weight * -(policies[i].log() * policies[i]).sum(1).mean(0)  # Maximise entropy: subtract β∙H from the loss

    # Value update dθ ← dθ - ∇θ∙1/2∙(Qret - Q(s_i, a_i; θ))^2
    Q = Qs[i].gather(1, actions[i])
    value_loss += ((Qret - Q) ** 2 / 2).mean(0)  # Least squares loss
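    # Note: minimising this term regresses Q(s_i, a_i; θ) towards the Retrace target Qret.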
    # Truncated importance weight ρ̄_a_i = min(1, ρ_a_i)
    truncated_rho = rho.gather(1, actions[i]).clamp(max=1)
    # Qret ← ρ̄_a_i∙(Qret - Q(s_i, a_i; θ)) + V(s_i; θ)
    Qret = truncated_rho * (Qret - Q.detach()) + Vs[i].detach()
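    # Note: Q and V are detached here so the recursion only builds a fixed target;
    # no gradient flows back through Qret (see the standalone Retrace sketch after
    # this function).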
    # Auxiliary classification loss
    class_loss += F.binary_cross_entropy(pred_class[i], target_class)
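    # Note: this variant augments the standard ACER losses with a supervised binary
    # cross-entropy term; target_class is the same label for every step of the trajectory.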
  # Optionally normalise loss by number of time steps
  if not args.no_time_normalisation:
    policy_loss /= t
    value_loss /= t
    class_loss /= t

  # Update networks
  _update_networks(args, T, model, shared_model, shared_average_model, policy_loss + value_loss + class_loss, optimiser)
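

# A minimal, self-contained sketch of the Q^ret recursion implemented in the loop above,
# written for a single un-batched trajectory of plain Python floats. The function name
# and arguments (rewards, q_taken, values, rho_taken, q_bootstrap, discount) are
# hypothetical, chosen only for illustration.
def _retrace_targets_sketch(rewards, q_taken, values, rho_taken, q_bootstrap, discount):
  qret = q_bootstrap  # Q^ret of the state after the last stored transition (0 if terminal)
  targets = []
  for r, q, v, rho in zip(reversed(rewards), reversed(q_taken), reversed(values), reversed(rho_taken)):
    qret = r + discount * qret  # Qret ← r_i + γQret
    targets.append(qret)  # Regression target for Q(s_i, a_i)
    qret = min(1.0, rho) * (qret - q) + v  # Qret ← ρ̄_a_i∙(Qret - Q(s_i, a_i)) + V(s_i)
  return list(reversed(targets))  # Targets ordered from the first to the last time step


# A sketch of the trust-region projection from the ACER paper, z* = g - max(0, (k·g - δ) / ||k||²)∙k,
# where g is the policy gradient and k the gradient of KL[average policy || current policy],
# both taken with respect to the policy's output distribution. This is an assumption about
# what _trust_region_loss computes internally, not a drop-in replacement for it.
def _trust_region_project_sketch(g, k, delta):
  kg = (k * g).sum(1, keepdim=True)  # k·g per batch element
  k_norm_sq = (k * k).sum(1, keepdim=True)  # ||k||² per batch element
  scale = ((kg - delta) / (k_norm_sq + 1e-10)).clamp(min=0)  # max(0, (k·g - δ) / ||k||²)
  return g - scale * k  # Adjusted gradient z*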