pc_train.py source code

python

Project: malmo-challenge    Author: Kaixhin
import torch
import torch.nn.functional as F
from torch.autograd import Variable

# ACTION_SIZE, _trust_region_loss and _update_networks are referenced below but defined elsewhere in the original module

def _train(args, T, model, shared_model, shared_average_model, optimiser, policies, Qs, Vs, actions, rewards, Qret, average_policies, target_class, pred_class, old_policies=None):
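  # ACER-style update: step backwards through the sampled trajectory, building
  # the Retrace target Qret, the truncated importance-sampled policy gradient
  # with bias correction, the Q-value regression loss and an auxiliary
  # classification loss, then push the gradients to the shared model.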
  off_policy = old_policies is not None
  policy_loss, value_loss, class_loss = 0, 0, 0

  # Calculate n-step returns in forward view, stepping backwards from the last state
  t = len(rewards)
  for i in reversed(range(t)):
    # Importance sampling weights ρ ← π(∙|s_i) / µ(∙|s_i); 1 for on-policy
    rho = policies[i].detach() / old_policies[i] if off_policy else Variable(torch.ones(1, ACTION_SIZE))

    # Qret ← r_i + γQret
    Qret = rewards[i] + args.discount * Qret
    # Advantage A ← Qret - V(s_i; θ)
    A = Qret - Vs[i]

    # Log policy log(π(a_i|s_i; θ))
    log_prob = policies[i].gather(1, actions[i]).log()
    # g ← min(c, ρ_a_i)∙∇θ∙log(π(a_i|s_i; θ))∙A
    single_step_policy_loss = -(rho.gather(1, actions[i]).clamp(max=args.trace_max) * log_prob * A).mean(0)  # Average over batch
    # Off-policy bias correction
    if off_policy:
      # g ← g + Σ_a [1 - c/ρ_a]_+∙π(a|s_i; θ)∙∇θ∙log(π(a|s_i; θ))∙(Q(s_i, a; θ) - V(s_i; θ))
      bias_weight = (1 - args.trace_max / rho).clamp(min=0) * policies[i]
      single_step_policy_loss -= (bias_weight * policies[i].log() * (Qs[i].detach() - Vs[i].expand_as(Qs[i]).detach())).sum(1).mean(0)
    if args.trust_region:
      # Policy update dθ ← dθ + ∂θ/∂θ∙z*
      policy_loss += _trust_region_loss(model, policies[i], average_policies[i], single_step_policy_loss, args.trust_region_threshold)
    else:
      # Policy update dθ ← dθ + ∂θ/∂θ∙g
      policy_loss += single_step_policy_loss

    # Entropy regularisation dθ ← dθ - β∙∇θH(π(s_i; θ))
    policy_loss += args.entropy_weight * -(policies[i].log() * policies[i]).sum(1).mean(0)

    # Value update dθ ← dθ - ∇θ∙1/2∙(Qret - Q(s_i, a_i; θ))^2
    Q = Qs[i].gather(1, actions[i])
    value_loss += ((Qret - Q) ** 2 / 2).mean(0)  # Least squares loss

    # Truncated importance weight ρ¯_a_i = min(1, ρ_a_i)
    truncated_rho = rho.gather(1, actions[i]).clamp(max=1)
    # Qret ← ρ¯_a_i∙(Qret - Q(s_i, a_i; θ)) + V(s_i; θ)
    Qret = truncated_rho * (Qret - Q.detach()) + Vs[i].detach()

    # Auxiliary classification loss (binary cross-entropy against target_class)
    class_loss += F.binary_cross_entropy(pred_class[i], target_class)

  # Optionally normalise loss by number of time steps
  if not args.no_time_normalisation:
    policy_loss /= t
    value_loss /= t
    class_loss /= t
  # Update networks
  _update_networks(args, T, model, shared_model, shared_average_model, policy_loss + value_loss + class_loss, optimiser)


# Acts and trains model
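
The listing is truncated here (the `# Acts and trains model` comment introduces the next function in the original file). For reference, below is a minimal, self-contained sketch, not taken from the malmo-challenge source, of the per-step quantities the loop above computes: the truncated importance-sampled policy gradient with its bias-correction term, and the Retrace target update. It uses plain tensors instead of `Variable`; the batch size, action count, truncation threshold c (`args.trace_max`) and discount (`args.discount`) are made-up values for illustration only.

python
import torch

torch.manual_seed(0)
batch, num_actions = 2, 3   # stand-ins for the batch size and ACTION_SIZE
c, gamma = 10.0, 0.99       # stand-ins for args.trace_max and args.discount

pi = torch.softmax(torch.randn(batch, num_actions), dim=1)  # current policy pi(.|s_i)
mu = torch.softmax(torch.randn(batch, num_actions), dim=1)  # behaviour policy mu(.|s_i)
Q = torch.randn(batch, num_actions)                         # Q(s_i, .; theta)
V = (pi * Q).sum(1, keepdim=True)                           # V(s_i; theta) = E_pi[Q]
a = torch.randint(num_actions, (batch, 1))                  # actions a_i taken in the rollout
r = torch.randn(batch, 1)                                   # rewards r_i
Qret = torch.randn(batch, 1)                                # return carried back from step i+1

# Importance weights rho = pi / mu (elementwise over actions)
rho = pi.detach() / mu

# Qret <- r_i + gamma * Qret, advantage A = Qret - V
Qret = r + gamma * Qret
A = Qret - V

# Truncated on-policy term: -min(c, rho_a) * log pi(a_i|s_i) * A, averaged over the batch
log_prob = pi.gather(1, a).log()
policy_loss = -(rho.gather(1, a).clamp(max=c) * log_prob * A).mean(0)

# Bias correction summed over all actions: [1 - c/rho_a]_+ * pi(a|s_i) * log pi(a|s_i) * (Q - V)
bias_weight = (1 - c / rho).clamp(min=0) * pi
policy_loss -= (bias_weight * pi.log() * (Q.detach() - V.detach())).sum(1).mean(0)

# Retrace target carried to the previous step: min(1, rho_a) * (Qret - Q(s_i, a_i)) + V(s_i)
Q_a = Q.gather(1, a)
Qret = rho.gather(1, a).clamp(max=1) * (Qret - Q_a.detach()) + V.detach()

print(policy_loss.item(), Qret.squeeze(1))

Running this once prints a scalar policy-loss term and the Retrace targets that the loop in `_train` would pass on to the preceding time step.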