def __init__(self, nFeatures, args):
    """Create the fixed matrices and the learnable matrix ``D``.

    Every size is derived from ``nFeatures``: ``nHidden = 2*nFeatures - 1``,
    ``nineq = 2*nFeatures - 2`` and ``neq = 0``.

    :param nFeatures: feature count from which all other sizes are derived
    :param args: stored unchanged on ``self.args``; not inspected here
    """
    super().__init__()

    nHidden = 2*nFeatures - 1
    nineq = 2*nFeatures - 2
    neq = 0
    assert neq == 0

    # Plain bookkeeping attributes.
    self.nFeatures, self.nHidden = nFeatures, nHidden
    self.neq, self.nineq = neq, nineq
    self.args = args

    # self.fc1 = nn.Linear(nFeatures, nHidden)

    # Lower-triangular matrix of ones, moved to the GPU.
    self.M = Variable(torch.ones(nHidden, nHidden).tril().cuda())

    # Diagonal matrix: ones on the leading nFeatures entries, 1e-8 elsewhere.
    quad = torch.eye(nHidden) * 1e-8
    quad[:nFeatures, :nFeatures] = torch.eye(nFeatures)
    # L is the Cholesky factor returned by torch.potrf; it is not a Parameter here.
    self.L = Variable(torch.potrf(quad))

    # The only learnable tensor created in this constructor.
    self.D = Parameter(0.3*torch.randn(nFeatures - 1, nFeatures))
    # self.lam = Parameter(20.*torch.ones(1))
    self.h = Variable(torch.zeros(nineq))
# Example source code showing usage of Python's (PyTorch) tril()
def __init__(self, nHidden=50, nineq=200, neq=0, eps=1e-4):
    """Create the convolution/linear layers and the matrices ``M``, ``L``, ``G``.

    :param nHidden: output width of ``qp_o`` and ``qp_z0``; side length of
        ``M`` and ``L``; column count of ``G``
    :param nineq: output width of ``qp_s0`` and row count of ``G``
    :param neq: must be 0 (asserted)
    :param eps: stored on ``self.eps``; not used in this constructor
    """
    super(LenetOptNet, self).__init__()

    # Input width shared by the three linear layers below.
    flat_width = 50*4*4

    self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
    self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
    self.qp_o = nn.Linear(flat_width, nHidden)
    self.qp_z0 = nn.Linear(flat_width, nHidden)
    self.qp_s0 = nn.Linear(flat_width, nineq)

    assert neq == 0

    # Plain bookkeeping attributes.
    self.nHidden, self.nineq = nHidden, nineq
    self.neq, self.eps = neq, eps

    # Fixed lower-triangular matrix of ones on the GPU.
    self.M = Variable(torch.ones(nHidden, nHidden).tril().cuda())
    # Learnable tensors, created on the GPU: L starts as the lower triangle
    # of uniform [0, 1) noise, G as uniform noise in [-1, 1].
    self.L = Parameter(torch.rand(nHidden, nHidden).cuda().tril())
    self.G = Parameter(torch.Tensor(nineq, nHidden).uniform_(-1, 1).cuda())
    # self.z0 = Parameter(torch.zeros(nHidden).cuda())
    # self.s0 = Parameter(torch.ones(nineq).cuda())
def forward(self, L, z):
    '''
    Multiply each ``z`` by a unit-diagonal lower-triangular matrix built from ``L``.

    Each row of ``L`` is reshaped to a square matrix, its strictly lower
    triangle is kept, and ones are placed on the diagonal.

    :param L: batch_size (B) x latent_size^2 (L^2)
    :param z: batch_size (B) x latent_size (L)
    :return: z_new = L*z, of size B x L
    '''
    size = self.args.z1_size

    # L -> tril(L): resize to get B x L x L
    L_matrix = L.view(-1, size, size)
    batch = L_matrix.size(0)

    # Strictly-lower-triangular mask (1s below the diagonal). The offset is
    # passed positionally because its keyword name differs between PyTorch
    # releases (`k` in old versions, `diagonal` in newer ones).
    LTmask = torch.tril(torch.ones(size, size), -1)
    I = Variable(torch.eye(size, size).expand(batch, size, size))
    if self.args.cuda:
        LTmask = LTmask.cuda()
        I = I.cuda()
    LTmask = Variable(LTmask)
    LTmask = LTmask.unsqueeze(0).expand(batch, size, size)  # 1 x L x L -> B x L x L

    # Batch of lower-triangular matrices with ones on the diagonal.
    LT = torch.mul(L_matrix, LTmask) + I

    # z_new = L * z : (B x L x L) * (B x L x 1) -> B x L
    z_new = torch.bmm(LT, z.unsqueeze(2)).squeeze(2)
    return z_new
def test_tril(self):
    """Check that ``torch.tril`` writes the same result into an ``out`` tensor."""
    x = torch.rand(SIZE, SIZE)
    res1 = torch.tril(x)
    res2 = torch.Tensor()
    # Input goes first and the destination is named via `out=`; the legacy
    # destination-first positional form is not accepted by later PyTorch.
    torch.tril(x, out=res2)
    # Third argument 0: presumably the allowed tolerance of the test
    # harness's assertEqual -- confirm against the TestCase base class.
    self.assertEqual(res1, res2, 0)
def __init__(self, nFeatures, nHidden, nCls, bn, nineq=200, neq=0, eps=1e-4):
    """Create two linear layers (optionally with batch norm) and the
    tensors ``M``, ``L``, ``G``, ``z0`` and ``s0``.

    :param nFeatures: input width of ``fc1``
    :param nHidden: output width of ``fc1`` / input width of ``fc2``
    :param nCls: output width of ``fc2``; side length of ``M`` and ``L``
    :param bn: when truthy, ``bn1`` and ``bn2`` batch-norm layers are created
    :param nineq: row count of ``G`` and length of ``s0``
    :param neq: must be 0 (asserted)
    :param eps: stored on ``self.eps``; not used in this constructor
    """
    super().__init__()

    self.nFeatures, self.nHidden = nFeatures, nHidden
    self.bn, self.nCls = bn, nCls

    # Batch-norm layers exist only when requested.
    if bn:
        self.bn1 = nn.BatchNorm1d(nHidden)
        self.bn2 = nn.BatchNorm1d(nCls)

    self.fc1 = nn.Linear(nFeatures, nHidden)
    self.fc2 = nn.Linear(nHidden, nCls)

    # self.qp_z0 = nn.Linear(nCls, nCls)
    # self.qp_s0 = nn.Linear(nCls, nineq)

    assert neq == 0

    # Fixed lower-triangular matrix of ones on the GPU.
    self.M = Variable(torch.ones(nCls, nCls).tril().cuda())
    # Learnable tensors, all created on the GPU.
    self.L = Parameter(torch.rand(nCls, nCls).cuda().tril())
    self.G = Parameter(torch.Tensor(nineq, nCls).uniform_(-1, 1).cuda())
    self.z0 = Parameter(torch.zeros(nCls).cuda())
    self.s0 = Parameter(torch.ones(nineq).cuda())

    self.nineq, self.neq, self.eps = nineq, neq, eps
def test_tril(self):
    """``torch.tril`` must give the same result with and without ``out=``."""
    source = torch.rand(SIZE, SIZE)
    expected = torch.tril(source)
    # Start from an empty tensor and let tril fill it via `out=`.
    written = torch.Tensor()
    torch.tril(source, out=written)
    self.assertEqual(expected, written, 0)
def test_tril(self):
    """``torch.tril`` must give the same result with and without ``out=``."""
    source = torch.rand(SIZE, SIZE)
    expected = torch.tril(source)
    # Start from an empty tensor and let tril fill it via `out=`.
    written = torch.Tensor()
    torch.tril(source, out=written)
    self.assertEqual(expected, written, 0)
def _make_cov(S):
L = torch.tril(torch.rand(S, S))
return torch.mm(L, L.t())
def test_tril(self):
    """``torch.tril`` must give the same result with and without ``out=``."""
    source = torch.rand(SIZE, SIZE)
    expected = torch.tril(source)
    # Start from an empty tensor and let tril fill it via `out=`.
    written = torch.Tensor()
    torch.tril(source, out=written)
    self.assertEqual(expected, written, 0)
def test_potrf(self):
    """Gradient-check ``torch.potrf`` for both the upper and lower factor.

    The matrix being factorised is ``root @ root.T``, where ``root`` is a
    random lower-triangular matrix that requires grad.
    """
    root = Variable(torch.tril(torch.rand(S, S)), requires_grad=True)

    for upper in (True, False):
        # `upper` is bound as a default so each closure keeps its own value.
        def factorize(mat, upper=upper):
            return torch.potrf(torch.mm(mat, mat.t()), upper)

        gradcheck(factorize, [root])
        gradgradcheck(factorize, [root])
def test_tril(self):
    """``torch.tril`` must give the same result with and without ``out=``."""
    source = torch.rand(SIZE, SIZE)
    expected = torch.tril(source)
    # Start from an empty tensor and let tril fill it via `out=`.
    written = torch.Tensor()
    torch.tril(source, out=written)
    self.assertEqual(expected, written, 0)
def test_trtrs(self):
    """Exercise ``torch.trtrs`` on upper and lower triangular systems.

    Covers plain and transposed solves, with the flags passed both
    partially and in full, compares transposed solves against solving with
    an explicitly transposed matrix, and finally solves into
    caller-supplied output tensors.
    """
    a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                      (-6.05, -3.30, 5.36, -4.44, 1.08),
                      (-0.45, 2.58, -2.70, 0.27, 9.04),
                      (8.32, 2.71, 4.35, -7.17, 2.14),
                      (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
    b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),
                      (-1.56, 4.00, -8.67, 1.75, 2.86),
                      (9.81, -4.09, -4.57, -8.61, 8.99))).t()

    U = torch.triu(a)
    L = torch.tril(a)

    # solve Ux = b
    x = torch.trtrs(b, U)[0]
    self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)
    x = torch.trtrs(b, U, True, False, False)[0]
    self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)

    # solve Lx = b
    x = torch.trtrs(b, L, False)[0]
    self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)
    x = torch.trtrs(b, L, False, False, False)[0]
    self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)

    # solve U'x = b
    x = torch.trtrs(b, U, True, True)[0]
    self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)
    x = torch.trtrs(b, U, True, True, False)[0]
    self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)

    # solve U'x = b by manual transposition
    y = torch.trtrs(b, U.t(), False, False)[0]
    self.assertLessEqual(x.dist(y), 1e-12)

    # solve L'x = b
    x = torch.trtrs(b, L, False, True)[0]
    self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)
    x = torch.trtrs(b, L, False, True, False)[0]
    self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)

    # solve L'x = b by manual transposition
    y = torch.trtrs(b, L.t(), True, False)[0]
    self.assertLessEqual(x.dist(y), 1e-12)

    # test reuse: inputs go first and the destinations are named via `out=`;
    # the legacy destination-first positional form is not accepted by later
    # PyTorch releases.
    res1 = torch.trtrs(b, a)[0]
    ta = torch.Tensor()
    tb = torch.Tensor()
    torch.trtrs(b, a, out=(tb, ta))
    self.assertEqual(res1, tb, 0)
    # Clear the destination so the second solve must rewrite it.
    tb.zero_()
    torch.trtrs(b, a, out=(tb, ta))
    self.assertEqual(res1, tb, 0)
def __init__(self, nFeatures, args):
    """Create ``fc1`` and initialise ``M``, ``L``, ``G``, ``z0`` and ``s0``.

    Sizes are derived from ``nFeatures``: ``nHidden = 2*nFeatures - 1``,
    ``nineq = 2*nFeatures - 2`` and ``neq = 0``.

    When ``args.tvInit`` is truthy the parameters come from a fixed
    construction: ``G`` is assembled from a first-difference matrix, and
    ``fc1`` is set to small noise minus an identity block with part of its
    bias set to ``lam``. Otherwise generic random/constant values are used.

    :param nFeatures: feature count from which all other sizes are derived
    :param args: must expose ``tvInit``; stored on ``self.args``
    """
    super(OptNet, self).__init__()

    nHidden = 2*nFeatures - 1
    nineq = 2*nFeatures - 2
    neq = 0
    assert neq == 0

    self.fc1 = nn.Linear(nFeatures, nHidden)
    # Fixed lower-triangular matrix of ones on the GPU.
    self.M = Variable(torch.ones(nHidden, nHidden).tril().cuda())

    if not args.tvInit:
        # Generic initialisation.
        self.L = Parameter(torch.tril(torch.rand(nHidden, nHidden)))
        self.G = Parameter(torch.Tensor(nineq, nHidden).uniform_(-1, 1))
        self.z0 = Parameter(torch.zeros(nHidden))
        self.s0 = Parameter(torch.ones(nineq))
    else:
        nDiff = nFeatures - 1

        # Diagonal matrix: ones on the leading nFeatures entries, 1e-8 elsewhere.
        quad = 1e-8*torch.eye(nHidden)
        quad[:nFeatures, :nFeatures] = torch.eye(nFeatures)
        self.L = Parameter(torch.potrf(quad))

        # First-difference matrix: +1 at [i, i] and -1 at [i, i+1].
        diff = torch.zeros(nDiff, nFeatures)
        diff[:, :nDiff] = torch.eye(nDiff)
        diff[:, 1:] -= torch.eye(nDiff)

        # Assembled by `block` (defined elsewhere) from a 2x2 grid of pieces.
        G_init = block(((diff, -torch.eye(nDiff)),
                        (-diff, -torch.eye(nDiff))))
        self.G = Parameter(G_init)
        self.s0 = Parameter(torch.ones(nineq) + 1e-6*torch.randn(nineq))

        # Regularised pseudo-inverse of G: (G'G + 1e-5 I)^-1 G'.
        gram = G_init.t().mm(G_init) + 1e-5*torch.eye(nHidden)
        G_pinv = gram.inverse().mm(G_init.t())
        self.z0 = Parameter(-G_pinv.mv(self.s0.data) + 1e-6*torch.randn(nHidden))

        lam = 21.21
        weight, bias = self.fc1.weight, self.fc1.bias
        # Small noise everywhere, minus an identity on the leading block.
        weight.data[:, :] = 1e-3*torch.randn((nHidden, nFeatures))
        weight.data[:nFeatures, :nFeatures] -= torch.eye(nFeatures)
        # Zero bias, except the trailing nFeatures-1 entries, which are lam.
        bias.data[:] = 0.0
        bias.data[nFeatures:nHidden] = lam

    self.nFeatures, self.nHidden = nFeatures, nHidden
    self.neq, self.nineq = neq, nineq
    self.args = args
def test_trtrs(self):
    """Exercise ``torch.trtrs`` on upper and lower triangular systems.

    Covers plain and transposed solves, with the flags passed both
    partially and in full, compares transposed solves against solving with
    an explicitly transposed matrix, and finally solves into
    caller-supplied output tensors.
    """
    tol = 1e-12
    a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                      (-6.05, -3.30, 5.36, -4.44, 1.08),
                      (-0.45, 2.58, -2.70, 0.27, 9.04),
                      (8.32, 2.71, 4.35, -7.17, 2.14),
                      (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
    b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),
                      (-1.56, 4.00, -8.67, 1.75, 2.86),
                      (9.81, -4.09, -4.57, -8.61, 8.99))).t()
    upper_tri = torch.triu(a)
    lower_tri = torch.tril(a)

    def solve(coeff, *flags):
        # Element 0 of the result is the solution; flags are forwarded positionally.
        return torch.trtrs(b, coeff, *flags)[0]

    def check_residual(coeff, sol):
        # coeff @ sol must reproduce b within the tolerance.
        self.assertLessEqual(b.dist(torch.mm(coeff, sol)), tol)

    # U x = b
    check_residual(upper_tri, solve(upper_tri))
    check_residual(upper_tri, solve(upper_tri, True, False, False))

    # L x = b
    check_residual(lower_tri, solve(lower_tri, False))
    check_residual(lower_tri, solve(lower_tri, False, False, False))

    # U' x = b
    check_residual(upper_tri.t(), solve(upper_tri, True, True))
    sol_t = solve(upper_tri, True, True, False)
    check_residual(upper_tri.t(), sol_t)
    # ... and the same system via manual transposition
    self.assertLessEqual(sol_t.dist(solve(upper_tri.t(), False, False)), tol)

    # L' x = b
    check_residual(lower_tri.t(), solve(lower_tri, False, True))
    sol_t = solve(lower_tri, False, True, False)
    check_residual(lower_tri.t(), sol_t)
    # ... and the same system via manual transposition
    self.assertLessEqual(sol_t.dist(solve(lower_tri.t(), True, False)), tol)

    # Reuse of output tensors: solve twice into the same destinations.
    expected = torch.trtrs(b, a)[0]
    out_coeff = torch.Tensor()
    out_sol = torch.Tensor()
    torch.trtrs(b, a, out=(out_sol, out_coeff))
    self.assertEqual(expected, out_sol, 0)
    out_sol.zero_()
    torch.trtrs(b, a, out=(out_sol, out_coeff))
    self.assertEqual(expected, out_sol, 0)
def test_trtrs(self):
    """Exercise ``torch.trtrs`` on upper and lower triangular systems.

    Covers plain and transposed solves, with the flags passed both
    partially and in full, compares transposed solves against solving with
    an explicitly transposed matrix, and finally solves into
    caller-supplied output tensors.
    """
    tol = 1e-12
    a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                      (-6.05, -3.30, 5.36, -4.44, 1.08),
                      (-0.45, 2.58, -2.70, 0.27, 9.04),
                      (8.32, 2.71, 4.35, -7.17, 2.14),
                      (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
    b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),
                      (-1.56, 4.00, -8.67, 1.75, 2.86),
                      (9.81, -4.09, -4.57, -8.61, 8.99))).t()
    upper_tri = torch.triu(a)
    lower_tri = torch.tril(a)

    def solve(coeff, *flags):
        # Element 0 of the result is the solution; flags are forwarded positionally.
        return torch.trtrs(b, coeff, *flags)[0]

    def check_residual(coeff, sol):
        # coeff @ sol must reproduce b within the tolerance.
        self.assertLessEqual(b.dist(torch.mm(coeff, sol)), tol)

    # U x = b
    check_residual(upper_tri, solve(upper_tri))
    check_residual(upper_tri, solve(upper_tri, True, False, False))

    # L x = b
    check_residual(lower_tri, solve(lower_tri, False))
    check_residual(lower_tri, solve(lower_tri, False, False, False))

    # U' x = b
    check_residual(upper_tri.t(), solve(upper_tri, True, True))
    sol_t = solve(upper_tri, True, True, False)
    check_residual(upper_tri.t(), sol_t)
    # ... and the same system via manual transposition
    self.assertLessEqual(sol_t.dist(solve(upper_tri.t(), False, False)), tol)

    # L' x = b
    check_residual(lower_tri.t(), solve(lower_tri, False, True))
    sol_t = solve(lower_tri, False, True, False)
    check_residual(lower_tri.t(), sol_t)
    # ... and the same system via manual transposition
    self.assertLessEqual(sol_t.dist(solve(lower_tri.t(), True, False)), tol)

    # Reuse of output tensors: solve twice into the same destinations.
    expected = torch.trtrs(b, a)[0]
    out_coeff = torch.Tensor()
    out_sol = torch.Tensor()
    torch.trtrs(b, a, out=(out_sol, out_coeff))
    self.assertEqual(expected, out_sol, 0)
    out_sol.zero_()
    torch.trtrs(b, a, out=(out_sol, out_coeff))
    self.assertEqual(expected, out_sol, 0)
def test_trtrs(self):
    """Exercise ``torch.trtrs`` on upper and lower triangular systems.

    Covers plain and transposed solves, with the flags passed both
    partially and in full, compares transposed solves against solving with
    an explicitly transposed matrix, and finally solves into
    caller-supplied output tensors.
    """
    tol = 1e-12
    a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                      (-6.05, -3.30, 5.36, -4.44, 1.08),
                      (-0.45, 2.58, -2.70, 0.27, 9.04),
                      (8.32, 2.71, 4.35, -7.17, 2.14),
                      (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
    b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),
                      (-1.56, 4.00, -8.67, 1.75, 2.86),
                      (9.81, -4.09, -4.57, -8.61, 8.99))).t()
    upper_tri = torch.triu(a)
    lower_tri = torch.tril(a)

    def solve(coeff, *flags):
        # Element 0 of the result is the solution; flags are forwarded positionally.
        return torch.trtrs(b, coeff, *flags)[0]

    def check_residual(coeff, sol):
        # coeff @ sol must reproduce b within the tolerance.
        self.assertLessEqual(b.dist(torch.mm(coeff, sol)), tol)

    # U x = b
    check_residual(upper_tri, solve(upper_tri))
    check_residual(upper_tri, solve(upper_tri, True, False, False))

    # L x = b
    check_residual(lower_tri, solve(lower_tri, False))
    check_residual(lower_tri, solve(lower_tri, False, False, False))

    # U' x = b
    check_residual(upper_tri.t(), solve(upper_tri, True, True))
    sol_t = solve(upper_tri, True, True, False)
    check_residual(upper_tri.t(), sol_t)
    # ... and the same system via manual transposition
    self.assertLessEqual(sol_t.dist(solve(upper_tri.t(), False, False)), tol)

    # L' x = b
    check_residual(lower_tri.t(), solve(lower_tri, False, True))
    sol_t = solve(lower_tri, False, True, False)
    check_residual(lower_tri.t(), sol_t)
    # ... and the same system via manual transposition
    self.assertLessEqual(sol_t.dist(solve(lower_tri.t(), True, False)), tol)

    # Reuse of output tensors: solve twice into the same destinations.
    expected = torch.trtrs(b, a)[0]
    out_coeff = torch.Tensor()
    out_sol = torch.Tensor()
    torch.trtrs(b, a, out=(out_sol, out_coeff))
    self.assertEqual(expected, out_sol, 0)
    out_sol.zero_()
    torch.trtrs(b, a, out=(out_sol, out_coeff))
    self.assertEqual(expected, out_sol, 0)
def test_trtrs(self):
    """Exercise ``torch.trtrs`` on upper and lower triangular systems.

    Covers plain and transposed solves, with the flags passed both
    partially and in full, compares transposed solves against solving with
    an explicitly transposed matrix, and finally solves into
    caller-supplied output tensors.
    """
    tol = 1e-12
    a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                      (-6.05, -3.30, 5.36, -4.44, 1.08),
                      (-0.45, 2.58, -2.70, 0.27, 9.04),
                      (8.32, 2.71, 4.35, -7.17, 2.14),
                      (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
    b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),
                      (-1.56, 4.00, -8.67, 1.75, 2.86),
                      (9.81, -4.09, -4.57, -8.61, 8.99))).t()
    upper_tri = torch.triu(a)
    lower_tri = torch.tril(a)

    def solve(coeff, *flags):
        # Element 0 of the result is the solution; flags are forwarded positionally.
        return torch.trtrs(b, coeff, *flags)[0]

    def check_residual(coeff, sol):
        # coeff @ sol must reproduce b within the tolerance.
        self.assertLessEqual(b.dist(torch.mm(coeff, sol)), tol)

    # U x = b
    check_residual(upper_tri, solve(upper_tri))
    check_residual(upper_tri, solve(upper_tri, True, False, False))

    # L x = b
    check_residual(lower_tri, solve(lower_tri, False))
    check_residual(lower_tri, solve(lower_tri, False, False, False))

    # U' x = b
    check_residual(upper_tri.t(), solve(upper_tri, True, True))
    sol_t = solve(upper_tri, True, True, False)
    check_residual(upper_tri.t(), sol_t)
    # ... and the same system via manual transposition
    self.assertLessEqual(sol_t.dist(solve(upper_tri.t(), False, False)), tol)

    # L' x = b
    check_residual(lower_tri.t(), solve(lower_tri, False, True))
    sol_t = solve(lower_tri, False, True, False)
    check_residual(lower_tri.t(), sol_t)
    # ... and the same system via manual transposition
    self.assertLessEqual(sol_t.dist(solve(lower_tri.t(), True, False)), tol)

    # Reuse of output tensors: solve twice into the same destinations.
    expected = torch.trtrs(b, a)[0]
    out_coeff = torch.Tensor()
    out_sol = torch.Tensor()
    torch.trtrs(b, a, out=(out_sol, out_coeff))
    self.assertEqual(expected, out_sol, 0)
    out_sol.zero_()
    torch.trtrs(b, a, out=(out_sol, out_coeff))
    self.assertEqual(expected, out_sol, 0)