gmm_ridge.py source code

python

Project: HistoricalMap   Author: lennepkade
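import multiprocessing as mp

import numpy as np

# CV, GMMR and predict are defined elsewhere in gmm_ridge.py.
# predict must be a module-level function so that mp.Pool can pickle it.
# cross_validation is a method excerpted from its class (hence self).

def cross_validation(self, x, y, tau, v=5):
    '''
    Compute the cross-validated classification accuracy for each value of the
    regularization parameter tau and return the best value.
    Input:
        x   : the training samples, shape (n_samples, n_features)
        y   : the labels, shape (n_samples,)
        tau : array of regularization values to be tested
        v   : the number of folds
    Output:
        tau_best : the value of tau with the highest mean accuracy
        err      : the mean cross-validated accuracy for each value of tau
                   (higher is better, hence argmax below)
    '''
    ## Initialization
    n_tau = tau.size           # Number of parameter values to test
    cv = CV()                  # Holds the training (it) and testing (iT) indices of each fold
    cv.split_data_class(y, v)  # Pass v so the folds match the loops below (assumes signature split_data_class(y, v))
    err = np.zeros(n_tau)      # Accumulated accuracy per tau (name kept from the original)

    ## Train one GMM model per fold on that fold's training samples
    model_cv = []
    for i in range(v):
        model_cv.append(GMMR())
        model_cv[i].learn(x[cv.it[i], :], y[cv.it[i]])

    ## Evaluate the folds in parallel; each predict call returns one score per tau
    pool = mp.Pool()
    processes = [pool.apply_async(predict, args=(tau, model_cv[i], x[cv.iT[i], :], y[cv.iT[i]]))
                 for i in range(v)]
    pool.close()
    pool.join()
    for p in processes:
        err += p.get()
    err /= v

    ## Free memory
    del processes, pool, model_cv

    return tau[err.argmax()], err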
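CV.split_data_class is not part of this excerpt. Judging only from how cross_validation uses it (cv.it[i] as training indices, cv.iT[i] as testing indices, and "class" in its name), it presumably builds class-stratified folds. The sketch below illustrates that idea under this assumption; split_data_class here is a hypothetical standalone function, not the project's CV class, and its seed parameter is an addition for reproducibility.

```python
import numpy as np


def split_data_class(y, v=5, seed=0):
    """Hypothetical stand-in for CV.split_data_class: stratified v-fold split.

    Returns two lists of index arrays, `it` (training) and `iT` (testing),
    mirroring the cv.it / cv.iT attributes used in cross_validation.
    """
    y = np.asarray(y).ravel()
    rng = np.random.default_rng(seed)
    folds = [[] for _ in range(v)]
    for c in np.unique(y):
        # Shuffle the samples of class c and deal them into v roughly equal parts
        idx = rng.permutation(np.flatnonzero(y == c))
        for k, part in enumerate(np.array_split(idx, v)):
            folds[k].extend(part.tolist())
    # Each fold is used once for testing; all other samples form its training set
    iT = [np.sort(np.array(f, dtype=int)) for f in folds]
    it = [np.setdiff1d(np.arange(y.size), test) for test in iT]
    return it, iT
```

Splitting each class separately means every class appears in every test fold as long as it has at least v samples, so no fold is scored on a class its model never saw.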
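GMMR and the module-level predict helper are also outside this excerpt. The sketch below reproduces the selection logic of cross_validation end to end under stated assumptions: RidgeGaussianClassifier is a hypothetical, simplified stand-in for GMMR that fits one Gaussian per class and adds tau to the diagonal of each class covariance (one common form of ridge regularization; the project's exact form is not shown here), and fold_accuracy plays the role of predict, returning the percentage of correctly classified test samples for each tau. To keep it short, folds are plain random folds evaluated serially instead of class-stratified folds dispatched to mp.Pool.

```python
import numpy as np


class RidgeGaussianClassifier:
    """Hypothetical, simplified stand-in for GMMR: one Gaussian per class,
    with tau added to the diagonal of each class covariance (ridge term)."""

    def learn(self, x, y):
        self.classes = np.unique(y)
        self.means, self.covs, self.log_priors = [], [], []
        for c in self.classes:
            xc = x[y == c]
            self.means.append(xc.mean(axis=0))
            # atleast_2d keeps the single-feature case as a 1x1 matrix
            self.covs.append(np.atleast_2d(np.cov(xc, rowvar=False)))
            self.log_priors.append(np.log(xc.shape[0] / x.shape[0]))
        return self

    def predict(self, x, tau):
        scores = []
        for m, cov, lp in zip(self.means, self.covs, self.log_priors):
            reg = cov + tau * np.eye(cov.shape[0])  # ridge-regularized covariance
            _, logdet = np.linalg.slogdet(reg)
            d = x - m
            # Squared Mahalanobis distance of each sample to the class mean
            maha = np.sum(d * np.linalg.solve(reg, d.T).T, axis=1)
            scores.append(lp - 0.5 * (logdet + maha))
        # Pick the class with the highest log-posterior (up to a shared constant)
        return self.classes[np.argmax(np.column_stack(scores), axis=1)]


def fold_accuracy(tau, model, x_test, y_test):
    """Percentage of correctly classified test samples, one value per tau."""
    acc = np.zeros(tau.size)
    for j, t in enumerate(tau):
        acc[j] = 100.0 * np.mean(model.predict(x_test, t) == y_test)
    return acc


def select_tau(x, y, tau, v=5, seed=0):
    """Return the tau with the highest mean v-fold accuracy, and all mean accuracies."""
    rng = np.random.default_rng(seed)
    test_folds = np.array_split(rng.permutation(y.size), v)
    acc = np.zeros(tau.size)
    for test_idx in test_folds:
        train_idx = np.setdiff1d(np.arange(y.size), test_idx)
        model = RidgeGaussianClassifier().learn(x[train_idx], y[train_idx])
        acc += fold_accuracy(tau, model, x[test_idx], y[test_idx])
    acc /= v
    return tau[acc.argmax()], acc


if __name__ == "__main__":
    # Two well-separated synthetic classes in 2-D
    rng = np.random.default_rng(1)
    x = np.vstack([rng.normal(0.0, 1.0, (40, 2)), rng.normal(5.0, 1.0, (40, 2))])
    y = np.repeat([0, 1], 40)
    best, acc = select_tau(x, y, np.array([1e-3, 1e-1, 1.0, 10.0]))
    print(best, acc)
```

As in cross_validation, argmax returns the first maximum, so when several values of tau tie, the one listed first in the array is selected.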