models.py source code

python

Project: optnet    Author: locuslab
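import torch
from torch.autograd import Variable
from torch.nn import Parameter
from qpth.qp import QPFunction

# Method of the model class in models.py; self.M, self.L, self.D, self.h,
# self.nHidden, self.nFeatures, self.nineq and self.args are set in __init__ (not shown).
def forward(self, x):
    nBatch = x.size(0)

    # Q = (M*L)(M*L)^T + eps*I, positive definite for eps > 0
    L = self.M*self.L
    Q = L.mm(L.t()) + self.args.eps*Variable(torch.eye(self.nHidden).type_as(L.data))
    Q = Q.unsqueeze(0).expand(nBatch, self.nHidden, self.nHidden)

    # Inequality constraints G z <= h with G = [[D, -I], [-D, -I]]
    nI = Variable(-torch.eye(self.nFeatures-1).type_as(Q.data))
    G = torch.cat((
        torch.cat(( self.D, nI), 1),
        torch.cat((-self.D, nI), 1)
    ))
    G = G.unsqueeze(0).expand(nBatch, self.nineq, self.nHidden)
    h = self.h.unsqueeze(0).expand(nBatch, self.nineq)
    e = Variable(torch.Tensor())  # empty A and b: no equality constraints

    # Linear term p = [-x; 13*1]; the fixed 13. takes the place of the learnable self.lam
    # p = torch.cat((-x, self.lam.unsqueeze(0).expand(nBatch, self.nFeatures-1)), 1)
    p = torch.cat((-x, Parameter(13.*torch.ones(nBatch, self.nFeatures-1).type_as(x.data))), 1)

    # Solve the batched QP in double precision and keep the first nFeatures entries
    # (the remaining nFeatures-1 are the auxiliary variables paired with -I in G)
    x = QPFunction()(Q.double(), p.double(), G.double(), h.double(), e, e).float()
    x = x[:,:self.nFeatures]

    return x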
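The method passes qpth's QPFunction a batched QP of the form: minimize 1/2 z^T Q z + p^T z subject to G z <= h (A and b are empty, so there are no equality constraints). The dependency-free sketch below mirrors how G and p are assembled, under two assumptions that are not in the excerpt: self.D is taken to be the (n-1) x n first-difference matrix, and self.h is taken to be all zeros. The helper names (first_difference, build_G, build_p, matvec) are hypothetical. Writing z = [y; t], the rows of G encode D y - t <= 0 and -D y - t <= 0, i.e. |D y| <= t, so the lam * sum(t) part of p^T z behaves like a total-variation penalty lam * ||D y||_1. The concatenations also fix the sizes: nHidden = 2n - 1 and nineq = 2(n - 1).

```python
# Hypothetical, dependency-free sketch of the QP data built in forward().
# Assumptions (not from the excerpt): D is the first-difference matrix, h = 0.


def first_difference(n):
    """(n-1) x n matrix whose row i has -1 at column i and +1 at column i+1."""
    return [[-1.0 if j == i else (1.0 if j == i + 1 else 0.0) for j in range(n)]
            for i in range(n - 1)]


def build_G(D):
    """Stack [[D, -I], [-D, -I]], mirroring the torch.cat calls in forward()."""
    m = len(D)
    neg_I = [[-1.0 if j == i else 0.0 for j in range(m)] for i in range(m)]
    top = [D[i] + neg_I[i] for i in range(m)]
    bottom = [[-v for v in D[i]] + neg_I[i] for i in range(m)]
    return top + bottom


def build_p(x, lam=13.0):
    """Linear term [-x; lam * 1], matching the fixed 13. used in forward()."""
    return [-v for v in x] + [lam] * (len(x) - 1)


def matvec(A, z):
    """Plain matrix-vector product for lists of lists."""
    return [sum(a * b for a, b in zip(row, z)) for row in A]


n = 5
D = first_difference(n)
G = build_G(D)
h = [0.0] * len(G)  # assumed: self.h is all zeros
assert len(G) == 2 * (n - 1)                    # nineq
assert all(len(row) == 2 * n - 1 for row in G)  # nHidden

# Pick a signal y and set the slack t = |D y|; then z = [y; t] is feasible.
y = [0.0, 1.0, 1.0, 3.0, 2.0]
t = [abs(v) for v in matvec(D, y)]
z = y + t
print(all(g <= bound for g, bound in zip(matvec(G, z), h)))  # True

# The slack part of p^T z equals lam * ||D y||_1 (here 13 * 4).
p = build_p(y)
print(sum(pi * ti for pi, ti in zip(p[n:], t)))  # 52.0
```

The slack block t is the reason forward() returns only x[:, :self.nFeatures]: those leading entries are the denoised signal, and the rest only serve to linearize the absolute values.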