def test_torch():
    """Check that ``block`` assembles a PyTorch block matrix correctly.

    A reference matrix ``K_`` is built by hand with ``torch.cat`` and compared
    against the result of ``block``. In the ``block`` layout, the integer
    ``0`` stands for an all-zeros block and the string ``'I'`` for an
    identity block. Neither carries an explicit size, so ``block`` must infer
    it from the other entries in the same block row/column (as the reference
    construction below shows).

    The comparison is done twice: once with plain tensors, and once with a
    mix of ``Variable``-wrapped and plain tensors.
    """
    # Imported locally so the rest of the test module does not require torch.
    import torch
    from torch.autograd import Variable

    torch.manual_seed(0)
    nx, nineq, neq = 4, 6, 7
    Q = torch.randn(nx, nx)
    G = torch.randn(nineq, nx)
    A = torch.randn(neq, nx)
    D = torch.diag(torch.rand(nineq))

    def zeros(*shape):
        # Reference zero block cast to Q's tensor type, so every piece passed
        # to torch.cat shares the same dtype (and device type).
        return torch.zeros(*shape).type_as(Q)

    def eye(n):
        # Reference identity block, cast the same way as zeros().
        return torch.eye(n).type_as(Q)

    # Hand-built reference: each inner cat is one block row (joined along
    # columns, dim 1); the outer cat stacks the block rows (dim 0).
    K_ = torch.cat((
        torch.cat((Q, zeros(nx, nineq), G.t(), A.t()), 1),
        torch.cat((zeros(nineq, nx), D, eye(nineq), zeros(nineq, neq)), 1),
        torch.cat((G, eye(nineq), zeros(nineq, nineq + neq)), 1),
        torch.cat((A, zeros(neq, nineq + nineq + neq)), 1),
    ))

    # Same layout expressed with block's 0 / 'I' placeholders.
    K = block((
        (Q, 0, G.t(), A.t()),
        (0, D, 'I', 0),
        (G, 'I', 0, 0),
        (A, 0, 0, 0),
    ))
    # Check the size first: K - K_ would broadcast and could hide a shape bug.
    assert K.size() == K_.size()
    # block only copies values, so the match is expected to be exact.
    assert float((K - K_).norm()) == 0.0

    # Same layout again, mixing Variable-wrapped inputs with plain tensors
    # (G.t() and the last row's A are left unwrapped).
    K = block((
        (Variable(Q), 0, G.t(), Variable(A.t())),
        (0, Variable(D), 'I', 0),
        (Variable(G), 'I', 0, 0),
        (A, 0, 0, 0),
    ))
    assert K.data.size() == K_.size()
    assert float((K.data - K_).norm()) == 0.0
评论列表
文章目录