model_classes.py source code

python

Project: e2e-model-learning  Author: locuslab
def forward(self, mu, sig):
    # mu, sig: (nBatch, n) tensors
    nBatch, n = mu.size()

    # Find the solution via sequential quadratic programming (SQP),
    # without preserving gradients: iterate on detached copies of the inputs.
    # GLinearApprox, GQuadraticApprox and SolveSchedulingQP are defined
    # elsewhere in the project.
    z0 = mu.detach().clone()
    mu0 = mu.detach().clone()
    sig0 = sig.detach().clone()
    for i in range(20):
        # Linear (dg) and quadratic (d2g) approximations of the cost around z0
        dg = GLinearApprox(self.params["gamma_under"],
            self.params["gamma_over"])(z0, mu0, sig0)
        d2g = GQuadraticApprox(self.params["gamma_under"],
            self.params["gamma_over"])(z0, mu0, sig0)
        # Solve the resulting QP to get the next iterate
        z0_new = SolveSchedulingQP(self.params)(z0, mu0, dg, d2g)
        solution_diff = (z0 - z0_new).norm().item()
        print("+ SQP Iter: {}, Solution diff = {}".format(i, solution_diff))
        z0 = z0_new
        if solution_diff < 1e-10:
            break

    # Now that we have found the solution, compute the gradient-propagating
    # version at it: rebuild the approximations from the original mu and sig
    # so that autograd can differentiate through the final QP solve.
    dg = GLinearApprox(self.params["gamma_under"],
        self.params["gamma_over"])(z0, mu, sig)
    d2g = GQuadraticApprox(self.params["gamma_under"],
        self.params["gamma_over"])(z0, mu, sig)
    return SolveSchedulingQP(self.params)(z0, mu, dg, d2g)
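The SQP loop in the first half of forward can be illustrated on a small problem. The following is a minimal sketch under stated assumptions, not the project's code.

The assumptions are as follows. The function name sqp_box is hypothetical. The objective sum_i exp(z_i) - mu_i * z_i is a toy objective. Box bounds [lo, hi] stand in for the scheduling constraints that SolveSchedulingQP handles.

Each iteration builds first- and second-derivative information (the roles of dg and d2g) and solves the quadratic subproblem. It stops when the step norm falls below a tolerance. Because this toy model is separable with positive curvature, clipping the Newton step solves the box-constrained QP exactly. A general solver would be needed for coupled constraints.

```python
import math


def sqp_box(mu, lo=0.0, hi=1.0, max_iter=20, tol=1e-10):
    """Toy SQP: minimize sum_i exp(z_i) - mu_i * z_i s.t. lo <= z_i <= hi.

    Hypothetical stand-in for the scheduling problem, for illustration only.
    """
    # Warm start at mu (as forward does), projected into the box
    z = [min(max(m, lo), hi) for m in mu]
    for _ in range(max_iter):
        # First and second derivatives of the objective at z (roles of dg, d2g)
        dg = [math.exp(zi) - mi for zi, mi in zip(z, mu)]
        d2g = [math.exp(zi) for zi in z]
        # QP subproblem: minimize dg*d + 0.5*d2g*d**2 s.t. lo <= z + d <= hi.
        # It is separable with positive curvature, so clipping the
        # unconstrained Newton step gives its exact solution.
        z_new = [min(max(zi - gi / hii, lo), hi)
                 for zi, gi, hii in zip(z, dg, d2g)]
        # Same stopping rule as forward: norm of the change in the solution
        diff = math.sqrt(sum((a - b) ** 2 for a, b in zip(z, z_new)))
        z = z_new
        if diff < tol:
            break
    return z


z = sqp_box([0.5, 2.0, 5.0])
print(z)
```

For mu = [0.5, 2.0, 5.0], the middle coordinate converges to log(2). The first and last coordinates sit on the bounds 0 and 1, because their unconstrained minimizers log(0.5) and log(5) lie outside the box.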
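The second half of forward re-evaluates dg and d2g at the converged z0. This time it uses the original, gradient-carrying mu and sig, then takes one more solve from that point. A minimal sketch of why this gives meaningful gradients follows, under these assumptions:

- The helper name solve_then_step is hypothetical.
- The objective sum_i exp(z_i - mu_i) - mu_i * z_i is an unconstrained toy problem, with minimizer z* = mu + log(mu).
- A plain Newton step z - g/h replaces SolveSchedulingQP.

At a fixed point the gradient g is zero. Differentiating one step with z held fixed therefore gives -(dg/dmu)/h, which is the implicit-function-theorem derivative of z* (here 1 + 1/mu). The term coming from h's dependence on mu is multiplied by g and vanishes.

```python
import torch


def solve_then_step(mu, iters=50, tol=1e-12):
    """Toy "solve without gradients, then one differentiable step" pattern.

    Hypothetical objective: sum_i exp(z_i - mu_i) - mu_i * z_i, mu > 0,
    whose minimizer is z* = mu + log(mu).
    """
    # Stage 1: Newton iterations with autograd disabled, warm-started at mu
    with torch.no_grad():
        z = mu.clone()
        for _ in range(iters):
            h = torch.exp(z - mu)   # second derivative (diagonal Hessian)
            g = h - mu              # first derivative
            z_new = z - g / h
            diff = (z_new - z).norm().item()
            z = z_new
            if diff < tol:
                break

    # Stage 2: one Newton step from the detached solution, tracked by autograd.
    # g and h both depend on mu; because g is ~0 here, the gradient
    # contribution through h is multiplied by g and drops out.
    h = torch.exp(z - mu)
    g = h - mu
    return z - g / h


mu = torch.tensor([0.5, 2.0, 4.0], dtype=torch.float64, requires_grad=True)
z = solve_then_step(mu)
z.sum().backward()
print(z.detach(), mu.grad)
```

Because stage 1 runs under torch.no_grad(), autograd records nothing for the solver iterations. The backward pass only goes through the final step, regardless of how many iterations the solve took.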