def __init__(self, nFeatures, nHidden, nCls, bn, nineq=200, neq=0, eps=1e-4):
    """Build the fully-connected layers and the learnable tensors L, G, z0, s0.

    L, G, z0 and s0 are presumably the parameters of a quadratic-program
    layer applied in ``forward`` (not visible here) -- confirm there.

    Args:
        nFeatures: Input width of ``fc1``.
        nHidden: Output width of ``fc1`` and input width of ``fc2``; also the
            feature count of ``bn1`` when ``bn`` is set.
        nCls: Output width of ``fc2``; also the size of ``L`` and ``M``
            (nCls x nCls), the column count of ``G`` and the length of ``z0``.
        bn: If truthy, additionally create ``bn1`` (BatchNorm1d over nHidden
            features) and ``bn2`` (BatchNorm1d over nCls features).
        nineq: Row count of ``G`` and length of ``s0`` (presumably the number
            of inequality constraints -- confirm against ``forward``).
        neq: Must be 0; no tensors for equality constraints are created.
        eps: Stored as ``self.eps``; not used by this constructor.

    Raises:
        ValueError: If ``neq`` is not 0.

    Note:
        Every tensor below is created with ``.cuda()``, so constructing this
        module requires a CUDA device. ``M`` is assigned as a plain attribute
        (neither a Parameter nor a registered buffer), so it does not appear
        in ``state_dict()`` and is not moved by ``Module.to()``/``.cpu()``.
    """
    # Validate up front with an explicit raise: an ``assert`` would be
    # stripped under ``python -O``, and failing before any allocation avoids
    # building layers / GPU tensors for a configuration we reject anyway.
    if neq != 0:
        raise ValueError(f"only neq == 0 is supported, got neq={neq}")

    super().__init__()
    self.nFeatures = nFeatures
    self.nHidden = nHidden
    self.bn = bn
    self.nCls = nCls
    self.nineq = nineq
    self.neq = neq
    self.eps = eps

    # Submodules and parameters are created in the same order as before:
    # registration order determines parameters()/state_dict() ordering, which
    # saved optimizer state depends on.
    if bn:
        self.bn1 = nn.BatchNorm1d(nHidden)
        self.bn2 = nn.BatchNorm1d(nCls)
    self.fc1 = nn.Linear(nFeatures, nHidden)
    self.fc2 = nn.Linear(nHidden, nCls)

    # Fixed lower-triangular mask of ones (not trained). Presumably applied to
    # L in forward to keep it lower-triangular -- TODO confirm.
    self.M = Variable(torch.tril(torch.ones(nCls, nCls)).cuda())
    # Learnable lower-triangular matrix, initialised with entries in [0, 1).
    self.L = Parameter(torch.tril(torch.rand(nCls, nCls).cuda()))
    # Learnable (nineq x nCls) matrix, initialised uniformly in [-1, 1).
    self.G = Parameter(torch.Tensor(nineq, nCls).uniform_(-1, 1).cuda())
    # Learnable vectors initialised to zeros (length nCls) and ones (length nineq).
    self.z0 = Parameter(torch.zeros(nCls).cuda())
    self.s0 = Parameter(torch.ones(nineq).cuda())
评论列表
文章目录