ica.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:cebl 作者: idfah 项目源码 文件源码
def train(self, s, kurtosis, learningRate, tolerance, maxIter, callback, verbose):
        """Iteratively fit the unmixing weights ``self.w`` to the data ``s``.

        Each iteration takes one gradient step on ``self.w``. The bracketed
        term ``I - k*tanh(y)'y - y'y`` appears to be an extended-infomax-style
        rule, where the sign ``k`` selects sub- or super-Gaussian sources.
        When the loop stops, each column of ``self.w`` is scaled to unit L2
        norm, and ``self.wInv`` is overwritten in place with its
        pseudo-inverse.

        Args:
            s: Input data. It is passed through ``self.prep`` first and then
                used as ``y = s.dot(self.w)``. The gradient is averaged over
                ``s.shape[0]``, so rows are treated as observations.
            kurtosis: ``'sub'`` (k = -1), ``'super'`` (k = +1), or
                ``'adapt'``. With ``'adapt'``, k is re-estimated every
                iteration as the sign of each output's kurtosis, and zero is
                treated as -1.
            learningRate: Step size applied to the gradient.
            tolerance: Training stops once the largest absolute change in
                any weight during one step falls below this value.
            maxIter: Iteration cap. Iterations are counted from 0 and the cap
                is checked after the update, so up to ``maxIter + 1`` steps
                are taken.
            callback: ``None``, or a callable invoked as
                ``callback(iteration, wtol)`` after every step.
            verbose: If true, print the per-iteration change and the stop
                reason.

        Side effects:
            Updates ``self.w`` in place. Sets ``self.reason`` to
            ``'tolerance'``, ``'diverge'`` or ``'maxiter'``. Writes into the
            existing ``self.wInv`` array.

        Raises:
            ValueError: If ``kurtosis`` is not one of the three accepted
                values.
        """
        # Validate up front, before self.prep runs. Previously an unknown
        # value only surfaced as an UnboundLocalError for ``k`` inside the
        # loop.
        if kurtosis not in ('sub', 'super', 'adapt'):
            raise ValueError(
                "Invalid kurtosis %r: expected 'sub', 'super' or 'adapt'."
                % (kurtosis,))

        s = self.prep(s)

        wPrev = np.empty(self.w.shape)
        grad = np.empty((self.nComp, self.nComp))

        I = np.eye(self.nComp, dtype=self.dtype)
        n = 1.0/s.shape[0]

        # Fixed kurtosis signs do not depend on the data, so set them once.
        # For 'adapt', k is recomputed inside the loop.
        if kurtosis == 'sub':
            k = -1
        elif kurtosis == 'super':
            k = 1

        iteration = 0
        while True:
            y = s.dot(self.w)

            if kurtosis == 'adapt':
                # Per-output sign of the sample kurtosis. np.sign yields
                # exactly 0 for zero kurtosis, and such outputs are treated
                # as sub-Gaussian.
                k = np.sign(spstat.kurtosis(y, axis=0))
                k[np.isclose(k, 0.0)] = -1.0

            # NOTE(review): ``n`` (1/N) multiplies the identity term as well
            # as the sample sums. The usual extended-infomax form averages
            # only the sums. Confirm this is intended before changing it,
            # since it shifts the fixed point.
            grad[...] = (I - k*util.fastTanh(y).T.dot(y) - y.T.dot(y)).T.dot(self.w) * n

            wPrev[...] = self.w
            self.w += learningRate * grad

            wtol = np.max(np.abs(wPrev - self.w))

            if verbose:
                print('%d %6f' % (iteration, wtol))

            if callback is not None:
                callback(iteration, wtol)

            if wtol < tolerance:
                self.reason = 'tolerance'
                break

            # Any comparison with NaN is False, so a magnitude test alone
            # lets NaN weights run on until maxIter. Treat any non-finite
            # weight as divergence.
            if (not np.all(np.isfinite(self.w))) or np.max(np.abs(self.w)) > 1.0e100:
                self.reason = 'diverge'
                break

            if iteration >= maxIter:
                self.reason = 'maxiter'
                break

            iteration += 1

        if verbose:
            print('Reason: ' + self.reason)

        # Scale each column of w to unit L2 norm, then refresh the inverse
        # in place.
        self.w /= np.sqrt(np.sum(self.w**2, axis=0))
        self.wInv[...] = np.linalg.pinv(self.w)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号