linreg.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:cebl 作者: idfah 项目源码 文件源码
def __init__(self, x, g,
                 elastic=1.0, penalty=0.0,
                 weightInitFunc=pinit.lecun,
                 optimFunc=optim.scg, **kwargs):
        x = np.asarray(x)
        g = np.asarray(g)
        self.dtype = np.result_type(x.dtype, g.dtype)

        if g.ndim > 1:
            self.flattenOut = False
        else:
            self.flattenOut = True

        self.elastic = elastic
        self.penalty = penalty

        Regression.__init__(self, util.colmat(x).shape[1],
                            util.colmat(g).shape[1])
        optim.Optable.__init__(self)

        self.weights = weightInitFunc((self.nIn+1, self.nOut)).astype(self.dtype, copy=False)

        if optimFunc is not None:
            self.train(x, g, optimFunc, **kwargs)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号