def _trust_region_loss(model, distribution, ref_distribution, loss, threshold):
# Compute gradients from original loss
model.zero_grad()
loss.backward(retain_graph=True)
# Gradients should be treated as constants (not using detach as volatility can creep in when double backprop is not implemented)
g = [Variable(param.grad.data.clone()) for param in model.parameters() if param.grad is not None]
model.zero_grad()
# KL divergence k ? ??0?DKL[?(?|s_i; ?_a) || ?(?|s_i; ?)]
kl = F.kl_div(distribution.log(), ref_distribution, size_average=False)
# Compute gradients from (negative) KL loss (increases KL divergence)
(-kl).backward(retain_graph=True)
k = [Variable(param.grad.data.clone()) for param in model.parameters() if param.grad is not None]
model.zero_grad()
# Compute dot products of gradients
k_dot_g = sum(torch.sum(k_p * g_p) for k_p, g_p in zip(k, g))
k_dot_k = sum(torch.sum(k_p ** 2) for k_p in k)
# Compute trust region update
if k_dot_k.data[0] > 0:
trust_factor = ((k_dot_g - threshold) / k_dot_k).clamp(min=0)
else:
trust_factor = Variable(torch.zeros(1))
# z* = g - max(0, (k^T?g - ?) / ||k||^2_2)?k
z_star = [g_p - trust_factor.expand_as(k_p) * k_p for g_p, k_p in zip(g, k)]
trust_loss = 0
for param, z_star_p in zip(model.parameters(), z_star):
trust_loss += (param * z_star_p).sum()
return trust_loss
# Trains model