def forward(self, x):
    """Build a batched quadratic program from ``x``, solve it, and return part of the solution.

    The QP inputs are passed to ``QPFunction`` in the order
    ``(Q, p, G, h, A, b)``. Assuming ``QPFunction`` follows the qpth
    convention, each batch row solves::

        minimize    1/2 z^T Q z + p^T z
        subject to  G z <= h        (no equality constraints: A, b are empty)

    where

    * ``Q = (M*L)(M*L)^T + args.eps * I``, identical for every batch row;
    * ``G = [[D, -I], [-D, -I]]`` with ``I`` of size (nFeatures-1)x(nFeatures-1),
      identical for every batch row;
    * ``h = self.h``, identical for every batch row;
    * ``p = [-x, lam * 1]`` with a fixed weight ``lam = 13``.

    NOTE(review): this structure looks like an l1 / total-variation style
    denoiser (``D`` a difference operator, the trailing nFeatures-1 variables
    bounding ``|D z|``) -- confirm against the model definition.

    Args:
        x: batch of inputs. It must be 2-D, because it is concatenated along
            dim 1. Presumably it has shape (nBatch, nFeatures), so that ``p``
            has ``nHidden`` columns -- TODO confirm.

    Returns:
        The first ``self.nFeatures`` columns of the QP solution, cast to float32.
    """
    nBatch = x.size(0)
    # Fixed weight on the linear term of the trailing QP variables.
    lam = 13.

    # Quadratic term. The eps*I ridge follows L's tensor type (as nI does
    # below) instead of a hard-coded .cuda(), so CPU models work too.
    L = self.M*self.L
    ridge = Variable(torch.eye(self.nHidden).type_as(L.data))
    Q = L.mm(L.t()) + self.args.eps*ridge
    Q = Q.unsqueeze(0).expand(nBatch, self.nHidden, self.nHidden)

    # Inequality constraints G z <= h.
    nI = Variable(-torch.eye(self.nFeatures-1).type_as(Q.data))
    G = torch.cat((
        torch.cat(( self.D, nI), 1),
        torch.cat((-self.D, nI), 1),
    ))
    G = G.unsqueeze(0).expand(nBatch, self.nineq, self.nHidden)
    h = self.h.unsqueeze(0).expand(nBatch, self.nineq)

    # Empty tensor passed as both A and b: no equality constraints.
    e = Variable(torch.Tensor())

    # Linear term. The constant block is a plain (non-grad) Variable that
    # follows x's tensor type. Wrapping it in a Parameter created on every
    # call only produced a gradient that nothing ever consumed.
    # NOTE: a variant using a learnable weight would put
    # self.lam.unsqueeze(0).expand(nBatch, self.nFeatures-1) here instead.
    lamBlock = Variable(lam*torch.ones(nBatch, self.nFeatures-1).type_as(x.data))
    p = torch.cat((-x, lamBlock), 1)

    # Solve in double precision, then cast back to float32.
    zhat = QPFunction()(Q.double(), p.double(), G.double(), h.double(), e, e).float()
    return zhat[:, :self.nFeatures]
评论列表
文章目录