models.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:optnet 作者: locuslab 项目源码 文件源码
def __init__(self, nHidden=50, nineq=200, neq=0, eps=1e-4):
        """Create the convolutional layers, linear heads and CUDA tensors.

        Every tensor built directly here is moved to the GPU with
        ``.cuda()``, so constructing this module requires a CUDA device.

        Args:
            nHidden: Output width of ``qp_o`` and ``qp_z0``; also the side
                length of the square tensors ``M`` and ``L`` and the
                column count of ``G``.
            nineq: Output width of ``qp_s0`` and the row count of ``G``
                (presumably the number of inequality constraints -- TODO
                confirm against the forward pass).
            neq: Must be 0; any other value trips the assertion below.
            eps: Stored unchanged on ``self.eps``; how it is used is not
                visible in this method.

        Raises:
            AssertionError: If ``neq`` is not 0.
        """
        super(LenetOptNet, self).__init__()

        # Two 5x5 convolutions: 1 -> 20 -> 50 channels.
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)

        # The three linear heads share one input width.
        # NOTE(review): 50*4*4 looks like the flattened conv2 output --
        # confirm against the forward pass.
        flat_features = 50 * 4 * 4
        self.qp_o = nn.Linear(flat_features, nHidden)
        self.qp_z0 = nn.Linear(flat_features, nHidden)
        self.qp_s0 = nn.Linear(flat_features, nineq)

        assert neq == 0

        # Plain bookkeeping copies of the constructor arguments.
        self.nHidden = nHidden
        self.nineq = nineq
        self.neq = neq
        self.eps = eps

        # M: constant lower-triangular matrix of ones. It is wrapped in
        # Variable rather than Parameter, so it is not registered as a
        # learnable parameter.
        lower_ones = torch.tril(torch.ones(nHidden, nHidden))
        self.M = Variable(lower_ones.cuda())

        # L: learnable, initialised to the lower triangle of a uniform
        # [0, 1) random matrix (moved to the GPU before masking).
        rand_square = torch.rand(nHidden, nHidden).cuda()
        self.L = Parameter(torch.tril(rand_square))

        # G: learnable (nineq x nHidden) matrix, uniform in [-1, 1).
        g_init = torch.Tensor(nineq, nHidden).uniform_(-1, 1)
        self.G = Parameter(g_init.cuda())

        # Disabled in the original: standalone z0 / s0 parameters.
        # self.z0 = Parameter(torch.zeros(nHidden).cuda())
        # self.s0 = Parameter(torch.ones(nineq).cuda())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号