def forward(self, mu, sig):
    """Solve the scheduling problem with sequential quadratic programming.

    The method works in two phases:

    1. An iterative search (at most 20 rounds) on detached copies of the
       data held by ``mu`` and ``sig``. Each round rebuilds the linear and
       quadratic approximations at the current iterate and solves the
       resulting QP. The search stops early once the norm of the change
       between consecutive iterates falls below ``1e-10``.
    2. A single final pass at the iterate found, this time fed with the
       original ``mu`` and ``sig`` so that the returned value is computed
       from the caller's inputs rather than from the detached copies.

    Args:
        mu: Input whose ``size()`` must unpack into exactly two values
            (presumably batch size and problem dimension -- confirm
            against callers). Its data also seeds the search.
        sig: Second input, forwarded together with ``mu`` to the
            approximation helpers.

    Returns:
        The result of ``SolveSchedulingQP(self.params)`` evaluated at the
        final iterate with the original ``mu`` and ``sig``.
    """
    # The unpacked values are not used afterwards; the unpacking itself
    # fails unless mu.size() has exactly two entries.
    _n_rows, _n_cols = mu.size()

    def build_approximations(point, mean, spread):
        # New approximation objects are created for every evaluation
        # instead of being shared between calls.
        first_order = GLinearApprox(
            self.params["gamma_under"], self.params["gamma_over"]
        )(point, mean, spread)
        second_order = GQuadraticApprox(
            self.params["gamma_under"], self.params["gamma_over"]
        )(point, mean, spread)
        return first_order, second_order

    # Phase 1: locate the solution without preserving gradients. All three
    # working variables wrap scaled copies of the raw data.
    iterate = Variable(1. * mu.data, requires_grad=False)
    mu_detached = Variable(1. * mu.data, requires_grad=False)
    sig_detached = Variable(1. * sig.data, requires_grad=False)

    for step in range(20):
        first_order, second_order = build_approximations(
            iterate, mu_detached, sig_detached
        )
        candidate = SolveSchedulingQP(self.params)(
            iterate, mu_detached, first_order, second_order
        )
        gap = iterate - candidate
        # NOTE(review): `.data[0]` depends on the PyTorch release in use;
        # newer releases expect `.item()` for scalar results -- confirm.
        solution_diff = gap.norm().data[0]
        print("+ SQP Iter: {}, Solution diff = {}".format(step, solution_diff))
        iterate = candidate
        if solution_diff < 1e-10:
            break

    # Phase 2: with the solution located, evaluate the gradient-propagating
    # version at that point using the caller's original inputs.
    first_order, second_order = build_approximations(iterate, mu, sig)
    return SolveSchedulingQP(self.params)(iterate, mu, first_order, second_order)
评论列表
文章目录