def cross_validation(self,x,y,tau,v=5):
    '''
    Score every candidate value of the regularization parameter with a v-fold
    cross validation and return the best-scoring candidate.

    One GMMR model is learned per fold on the samples indexed by cv.it[i];
    predict is then run in a pool of worker processes on the samples indexed
    by cv.iT[i], and the per-fold results are averaged.

    Input:
        x : the training samples (indexed by rows, x[idx,:])
        y : the labels, indexed like the rows of x
        tau : the array of regularization values to be tested
        v : the number of folds
    Output:
        tau_best : the entry of tau whose averaged result is the largest
        err : the results of predict averaged over the v folds, one entry per
              value of tau
    NOTE(review): the result is historically called "err", but the maximum is
    selected, so predict presumably returns an accuracy-like score -- confirm.
    '''
    ## Initialization
    n_tau = tau.size            # Number of parameters to test
    cv = CV()                   # Indices for the cross validation
    # NOTE(review): v is not forwarded here, so the number of folds built by CV
    # presumably comes from its own default -- confirm it matches v, otherwise
    # cv.it[i] / cv.iT[i] below may not line up with range(v).
    cv.split_data_class(y)
    err = sp.zeros(n_tau)       # Accumulator of the per-fold results

    ## Create GMM model for each fold
    model_cv = []
    for i in range(v):
        model = GMMR()
        model.learn(x[cv.it[i],:], y[cv.it[i]])
        model_cv.append(model)

    ## Evaluate every fold in a pool of processes
    pool = mp.Pool()
    try:
        processes = [pool.apply_async(predict,args=(tau,model_cv[i],x[cv.iT[i],:],y[cv.iT[i]])) for i in range(v)]
        pool.close()
        pool.join()
    finally:
        # Harmless after a clean close()/join(); releases the workers when
        # submitting the tasks failed half-way.
        pool.terminate()

    ## Average the results over the folds (get() re-raises worker exceptions)
    for p in processes:
        err += p.get()
    err /= v

    # Locals (models, pool, async results) are released when the function
    # returns, so no explicit del is needed.
    return tau[err.argmax()],err
# (Scraped web-page residue, not part of the code: "comment list" / "table of contents".)