forward.py 文件源码

python
阅读 40 收藏 0 点赞 0 评论 0

项目:cebl 作者: idfah 项目源码 文件源码
def __init__(self, x, g, nHidden=10, transFunc=transfer.lecun,
                 weightInitFunc=pinit.lecun, penalty=None, elastic=1.0,
                 optimFunc=optim.scg, **kwargs):
        """Build the layer weight views and optionally train the network.

        Args:
            x: Input data, converted with np.asarray.  The column count of
                util.colmat(x) is handed to Regression.__init__ as the
                input dimension.
            g: Target data, converted with np.asarray.  The column count of
                util.colmat(g) is handed to Regression.__init__ as the
                output dimension.  When g has at most one dimension,
                self.flattenOut is set to True.
            nHidden: Hidden layer size.  Either a single value (one hidden
                layer) or an iterable with one size per hidden layer.
            transFunc: Transfer function, or an iterable with exactly one
                entry per hidden layer.
            weightInitFunc: Callable taking a shape and returning an array
                of initial weights, or an iterable with one callable per
                hidden layer plus one for the output layer.
            penalty: None, a single value, or an iterable with one entry
                per hidden layer plus one for the output layer.  Only
                stored here.
            elastic: A single value, or an iterable with one entry per
                hidden layer plus one for the output layer.  Only stored
                here.
            optimFunc: Passed to self.train together with x and g.  When
                None, no training is performed by the constructor.
            **kwargs: Forwarded unchanged to self.train.

        Raises:
            AssertionError: If transFunc, weightInitFunc, penalty or
                elastic is an iterable of the wrong length.
        """
        x = np.asarray(x)
        g = np.asarray(g)
        self.dtype = np.result_type(x.dtype, g.dtype)

        # remember whether the targets were given without a column axis
        self.flattenOut = g.ndim <= 1

        # NOTE(review): self.nIn and self.nOut are read below and are
        # presumably set by Regression.__init__ -- confirm.
        Regression.__init__(self, util.colmat(x).shape[1],
                            util.colmat(g).shape[1])
        optim.Optable.__init__(self)

        self.nHidden = nHidden if util.isiterable(nHidden) else (nHidden,)
        self.nHLayers = len(self.nHidden)

        # each layer has one extra input row on top of the previous
        # layer's width (presumably the bias term -- confirm)
        self.layerDims = [(self.nIn+1, self.nHidden[0])]
        # range instead of xrange so the loop runs on Python 2 and 3
        for l in range(1, self.nHLayers):
            self.layerDims.append((self.nHidden[l-1]+1, self.nHidden[l]))
        self.layerDims.append((self.nHidden[-1]+1, self.nOut))

        self.transFunc = transFunc if util.isiterable(transFunc) \
                else (transFunc,) * self.nHLayers
        assert len(self.transFunc) == self.nHLayers, \
            "transFunc must have one entry per hidden layer"

        # NOTE(review): views[0] is presumably the packed array backing
        # all of the per-layer views -- confirm against util.packedViews.
        views = util.packedViews(self.layerDims, dtype=self.dtype)
        self.pw  = views[0]
        self.hws = views[1:-1]
        self.vw  = views[-1]

        if not util.isiterable(weightInitFunc):
            weightInitFunc = (weightInitFunc,) * (self.nHLayers+1)
        assert len(weightInitFunc) == (len(self.hws) + 1), \
            "weightInitFunc must have one entry per layer"

        # initialize weights in place so the views keep their storage
        for hw, wif in zip(self.hws, weightInitFunc):
            hw[...] = wif(hw.shape).astype(self.dtype, copy=False)
        self.vw[...] = weightInitFunc[-1](self.vw.shape).astype(self.dtype, copy=False)

        self.penalty = penalty
        if self.penalty is not None:
            if not util.isiterable(self.penalty):
                self.penalty = (self.penalty,) * (self.nHLayers+1)
        assert (self.penalty is None) or (len(self.penalty) == (len(self.hws) + 1)), \
            "penalty must have one entry per layer"

        self.elastic = elastic if util.isiterable(elastic) \
                else (elastic,) * (self.nHLayers+1)
        assert (len(self.elastic) == (len(self.hws) + 1)), \
            "elastic must have one entry per layer"

        # train the network
        if optimFunc is not None:
            self.train(x, g, optimFunc, **kwargs)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号