train.py source code

python

Project: ACER   Author: Kaixhin
# Imports required by this snippet
import torch
from torch.autograd import Variable
from torch.nn import functional as F


def _trust_region_loss(model, distribution, ref_distribution, loss, threshold):
  # Compute gradients from original loss
  model.zero_grad()
  loss.backward(retain_graph=True)
  # Gradients should be treated as constants (not using detach as volatility can creep in when double backprop is not implemented)
  g = [Variable(param.grad.data.clone()) for param in model.parameters() if param.grad is not None]
  model.zero_grad()

  # KL divergence k ← ∇θ0 · DKL[π(·|s_i; θ_a) || π(·|s_i; θ)]
  kl = F.kl_div(distribution.log(), ref_distribution, size_average=False)
  # Compute gradients from (negative) KL loss (increases KL divergence)
  (-kl).backward(retain_graph=True)
  k = [Variable(param.grad.data.clone()) for param in model.parameters() if param.grad is not None]
  model.zero_grad()

  # Compute dot products of gradients
  k_dot_g = sum(torch.sum(k_p * g_p) for k_p, g_p in zip(k, g))
  k_dot_k = sum(torch.sum(k_p ** 2) for k_p in k)
  # Compute trust region update
  if k_dot_k.data[0] > 0:
    trust_factor = ((k_dot_g - threshold) / k_dot_k).clamp(min=0)
  else:
    trust_factor = Variable(torch.zeros(1))
  # z* = g - max(0, (k^T · g - δ) / ||k||^2_2) · k
  z_star = [g_p - trust_factor.expand_as(k_p) * k_p for g_p, k_p in zip(g, k)]
  trust_loss = 0
  for param, z_star_p in zip(model.parameters(), z_star):
    trust_loss += (param * z_star_p).sum()
  return trust_loss


# Trains model
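The listing breaks off at the `# Trains model` comment that heads the next function in the file. For context, here is a minimal, self-contained sketch of how the returned trust_loss might be wired into an update step. It is written against the same pre-0.4 PyTorch API the snippet uses (Variable, size_average, .data[0]); the toy model, average_model, advantage and policy_loss below are illustrative placeholders, not taken from the original train.py. The point is that backpropagating trust_loss deposits the projected gradient z* into each param.grad.

import torch
from torch import nn, optim
from torch.autograd import Variable

# Toy stand-ins (illustrative only): a small softmax policy over 4 actions
model = nn.Sequential(nn.Linear(8, 4), nn.Softmax(dim=1))
average_model = nn.Sequential(nn.Linear(8, 4), nn.Softmax(dim=1))
average_model.load_state_dict(model.state_dict())
optimiser = optim.Adam(model.parameters(), lr=1e-3)

state = Variable(torch.randn(1, 8))
policy = model(state)                       # current policy π(·|s; θ)
ref_policy = average_model(state).detach()  # averaged policy π(·|s; θ_a), held constant
advantage = Variable(torch.ones(1))         # placeholder advantage estimate
policy_loss = -(policy[:, 0].log() * advantage).mean()  # stand-in policy objective

trust_loss = _trust_region_loss(model, policy, ref_policy, policy_loss, threshold=1.0)
# Backpropagating trust_loss fills each param.grad with the projected gradient z*,
# since d(sum_p θ_p · z*_p)/dθ_p = z*_p and z* was built from detached gradient data
trust_loss.backward()
optimiser.step()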