def _make_gmm(n_components, cv_type, means_init=None):
    """Return an unfitted Gaussian mixture, optionally seeded with means.

    Prefers ``mixture.GaussianMixture`` (scikit-learn >= 0.18) and falls back
    to the legacy ``mixture.GMM`` (removed in scikit-learn 0.20).

    Args:
        n_components: number of mixture components.
        cv_type: covariance type ('spherical', 'tied', 'diag' or 'full').
        means_init: optional (n_components, n_features) array of initial
            component means; ``None`` lets the estimator initialise them.
    """
    if hasattr(mixture, 'GaussianMixture'):
        return mixture.GaussianMixture(n_components=n_components,
                                       covariance_type=cv_type,
                                       means_init=means_init)
    if means_init is None:
        return mixture.GMM(n_components=n_components,
                           covariance_type=cv_type)
    # Legacy GMM re-initialises means_ with k-means inside fit() unless 'm'
    # is dropped from init_params, which would silently discard the seeds.
    gmm = mixture.GMM(n_components=n_components, covariance_type=cv_type,
                      init_params='wc')
    gmm.means_ = means_init
    return gmm


def _majority_class_accuracy(components, class_idx, n_components, n_classes):
    """Return the percentage of samples whose component's majority class is theirs.

    Each component is labelled with the class occurring most often among the
    samples assigned to it. Empty components get class index 0, which cannot
    affect the score because no sample is assigned to them.

    Args:
        components: integer component index predicted for every sample.
        class_idx: integer class index (0..n_classes-1) of every sample.
        n_components: number of mixture components.
        n_classes: number of distinct classes.
    """
    counts = np.zeros((n_components, n_classes), dtype=int)
    # Unbuffered accumulation so repeated (component, class) pairs all count.
    np.add.at(counts, (components, class_idx), 1)
    component_class = counts.argmax(axis=1)
    return np.mean(component_class[components] == class_idx) * 100


def gmm_cv(fv_train, target_train, fv_test, target_test,
           max_components=6, cv_types=('spherical', 'tied', 'diag', 'full')):
    """Grid-search GMM covariance type and component count on the train set.

    For every covariance type in ``cv_types`` and every component count from
    1 to ``max_components`` (inclusive), fit a Gaussian mixture with EM on
    ``fv_train``, print the train accuracy, and record it. When the component
    count equals the number of classes, the component means are seeded with
    the per-class feature means.

    Accuracy is measured by labelling each component with the majority class
    of the training samples assigned to it, so it is meaningful for any
    component count and for arbitrary label values.

    Args:
        fv_train: training feature matrix, indexable by a boolean mask
            (assumed to be a 2-D numpy array of shape (n_samples, n_features)).
        target_train: class label for each training sample.
        fv_test, target_test: currently unused; kept so existing callers
            keep working.
        max_components: largest component count to try (default 6).
        cv_types: covariance types to try.

    Returns:
        dict mapping ``(cv_type, n_components)`` to train accuracy in percent.
    """
    # Map arbitrary labels onto contiguous indices 0..n_classes-1.
    classes, class_idx = np.unique(np.ravel(target_train), return_inverse=True)
    n_classes = len(classes)
    # Loop-invariant: per-class feature means used to seed the mixture.
    class_means = np.array([fv_train[class_idx == i].mean(axis=0)
                            for i in range(n_classes)])

    results = {}
    for cv_type in cv_types:
        for n_components in range(1, max_components + 1):
            # Seeds must have exactly one row per component.
            means_init = class_means if n_components == n_classes else None
            gmm = _make_gmm(n_components, cv_type, means_init)
            gmm.fit(fv_train)
            components = gmm.predict(fv_train)
            train_accuracy = _majority_class_accuracy(
                components, class_idx, n_components, n_classes)
            results[(cv_type, n_components)] = train_accuracy
            # Single-argument print: identical output on Python 2 and 3.
            print('%s %s  Train accuracy: %.1f'
                  % (cv_type, n_components, train_accuracy))
    return results
# VUVclassification.py source file
# python
# Reads: 21
# Favorites: 0
# Likes: 0
# Comments: 0
# Comment list
# Table of contents