methods.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:South-African-Heart-Disease-data-analysis-using-python 作者: khushi4tiwari 项目源码 文件源码
def CVK(X, KRange, covar_type='diag', reps=10, n_folds=5):
    """Cross-validate Gaussian mixture models over a range of component counts.

    For every K in ``KRange``, a Gaussian mixture with K components is fitted
    on the training part of each cross-validation fold. The negative
    log-likelihood of the held-out part is accumulated into ``CVE``.

    Parameters
    ----------
    X : 2-D array
        Data matrix. ``X.shape`` is unpacked into two values and rows are
        selected by fold index, so rows are the observations.
    KRange : sequence of int
        Numbers of mixture components to evaluate.
    covar_type : str, optional
        Passed to ``GMM`` as ``covariance_type`` (default ``'diag'``).
    reps : int, optional
        Passed to ``GMM`` as ``n_init``, the number of initializations per
        fit (default 10).
    n_folds : int, optional
        Number of cross-validation folds (default 5, the value that was
        previously hard-coded).

    Returns
    -------
    CVE : ndarray of shape (len(KRange), 1)
        ``CVE[t]`` is the held-out negative log-likelihood for
        ``KRange[t]``, summed over the test sets of all folds.
    """
    # Unpack into two names so a non-2-D input still fails loudly here.
    N, _ = X.shape
    CVE = np.zeros((len(KRange), 1))

    # The folds are built once, outside the K loop. The legacy KFold shuffles
    # at construction, so presumably every K is scored on the same partition,
    # keeping CVE comparable across K. Confirm against the sklearn version in use.
    CV = cross_validation.KFold(N, n_folds, shuffle=True)

    for t, K in enumerate(KRange):
        print('Fitting model for K={0}\n'.format(K))

        # For each cross-validation fold: fit on the training rows and
        # score the held-out rows. (The original also fitted a model on the
        # full X here, but that model was discarded unused, so it is dropped.)
        for train_index, test_index in CV:
            X_train = X[train_index]
            X_test = X[test_index]

            gmm = GMM(n_components=K, covariance_type=covar_type,
                      n_init=reps, params='wmc').fit(X_train)

            # Assumes the legacy GMM.score returns per-sample log-likelihoods
            # (TODO confirm). Summing and negating gives this fold's NLL.
            CVE[t] += -gmm.score(X_test).sum()

    return CVE
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号