VUVclassification.py source code

python

Project: jingjuPhoneticSegmentation  Author: ronggong
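import numpy as np
from sklearn import mixture


def gmm_cv(fv_train, target_train, fv_test, target_test):

    ####---- grid search over GMM covariance type and number of components;
    ####---- accuracy is measured on the training set only (fv_test, target_test are unused)

    n_classes           = len(np.unique(target_train))
    n_components_max    = 7

    # per-class means of the training features, used to seed the mixture components
    class_means         = np.array([fv_train[target_train == i].mean(axis=0)
                                    for i in range(n_classes)])

    n_components_range  = range(1, n_components_max)
    cv_types            = ['spherical', 'tied', 'diag', 'full']
    for cv_type in cv_types:
        for n_components in n_components_range:
            # means_init must have shape (n_components, n_features), so the class
            # means can only seed the model when n_components == n_classes
            means_init  = class_means if n_components == n_classes else None
            # fit a mixture of Gaussians with EM (mixture.GMM was removed in
            # scikit-learn 0.20; GaussianMixture is its replacement)
            gmm         = mixture.GaussianMixture(n_components=n_components,
                                                  covariance_type=cv_type,
                                                  means_init=means_init)
            gmm.fit(fv_train)

            # component indices correspond to class labels only when the components
            # were seeded with the class means (n_components == n_classes)
            target_train_pred   = gmm.predict(fv_train)
            train_accuracy      = np.mean(target_train_pred == target_train) * 100

            print(cv_type, n_components, ' Train accuracy: %.1f' % train_accuracy)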
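The header comment promises cross-validation and the signature accepts fv_test/target_test, but the body only reports training accuracy. Comparing GMM component indices with class labels is also only meaningful when each component is seeded with one class mean. A minimal sketch of one common alternative follows. It is a swapped-in technique, not the author's method: fit one GaussianMixture per class, classify each frame by the highest log-likelihood, pick covariance_type and n_components by stratified k-fold cross-validation, then score the chosen setting on the held-out test set. The names gmm_cv_sketch, fit_class_gmms and predict_class_gmms, and the defaults n_splits=5 and seed=0, are hypothetical. The sketch assumes scikit-learn 0.20 or later (GaussianMixture, StratifiedKFold).

```python
import numpy as np
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import StratifiedKFold

# Hypothetical sketch: per-class GMM classifier with k-fold grid search.


def fit_class_gmms(fv, target, n_components, cv_type, seed=0):
    """Fit one GaussianMixture on the frames of each class label."""
    classes = np.unique(target)
    gmms = [GaussianMixture(n_components=n_components,
                            covariance_type=cv_type,
                            random_state=seed).fit(fv[target == c])
            for c in classes]
    return classes, gmms


def predict_class_gmms(classes, gmms, fv):
    """Assign each frame to the class whose GMM gives the highest log-likelihood."""
    log_lik = np.column_stack([g.score_samples(fv) for g in gmms])
    return classes[np.argmax(log_lik, axis=1)]


def gmm_cv_sketch(fv_train, target_train, fv_test, target_test,
                  n_components_max=7, n_splits=5, seed=0):
    """Grid-search (cv_type, n_components) by stratified k-fold CV on the
    training set, then refit the best setting and score it on the test set."""
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    best_params, best_cv_acc = None, -1.0
    for cv_type in ['spherical', 'tied', 'diag', 'full']:
        for n_components in range(1, n_components_max):
            fold_acc = []
            for tr_idx, va_idx in skf.split(fv_train, target_train):
                classes, gmms = fit_class_gmms(fv_train[tr_idx], target_train[tr_idx],
                                               n_components, cv_type, seed)
                pred = predict_class_gmms(classes, gmms, fv_train[va_idx])
                fold_acc.append(np.mean(pred == target_train[va_idx]) * 100)
            cv_acc = float(np.mean(fold_acc))
            # keep the setting with the highest mean validation accuracy
            if cv_acc > best_cv_acc:
                best_params, best_cv_acc = (cv_type, n_components), cv_acc

    # refit on the full training set with the selected setting
    cv_type, n_components = best_params
    classes, gmms = fit_class_gmms(fv_train, target_train, n_components, cv_type, seed)
    test_pred = predict_class_gmms(classes, gmms, fv_test)
    test_acc = float(np.mean(test_pred == target_test) * 100)
    return best_params, best_cv_acc, test_acc
```

Usage, with the same arrays the original function receives: `best_params, cv_acc, test_acc = gmm_cv_sketch(fv_train, target_train, fv_test, target_test)`. Fitting one mixture per class lets n_components vary freely, because the predicted label comes from the likelihood comparison rather than from a component index.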