def search(config):
    """Sample hyper-parameter sets, evaluate them in parallel and rank them.

    ``config`` is modified in place: ``'max_iter'`` and ``'random_seed'`` are
    overwritten and every sampled parameter set is merged into it, so after
    the call it still carries the values of the last sample drawn.

    Args:
        config: Mapping read for ``'debug'``, ``'nb_process'`` and
            ``'fixed_params'``. It is also handed to ``get_agent_class`` to
            pick the agent class whose ``get_random_config`` produces the
            samples.

    Returns:
        A dict with two entries: ``'results'``, every value returned by
        ``test_params`` ordered by descending ``'mean_score'``, and
        ``'best_params'``, the ``'params'`` entry of the top-ranked result.

    Raises:
        IndexError: If no trial was scheduled (outside debug mode this
            happens when the agent exposes no hyper-parameters).
    """
    sample_params = get_agent_class(config).get_random_config
    # A call without arguments is used only to count the hyper-parameters.
    hp_names = list(sample_params().keys())
    hp_count = len(hp_names)

    # Debug mode uses a small 'max_iter' and only a handful of trials;
    # otherwise the number of trials scales with the hyper-parameter count.
    debug = config['debug']
    if debug:
        print('*** Number of hyper-parameters: %d' % hp_count)
        config['max_iter'] = 5
        trial_count = 5
    else:
        config['max_iter'] = 500
        trial_count = 200 * hp_count

    # Never start more workers than there are CPUs.
    worker_count = min(multiprocessing.cpu_count(), config['nb_process'])

    pending = []
    with concurrent.futures.ProcessPoolExecutor(worker_count) as pool:
        for trial_id in range(trial_count):
            sampled = sample_params(config["fixed_params"])
            config.update(sampled)
            config['random_seed'] = 1
            # Submit copies: config keeps being mutated by later iterations.
            job = pool.submit(
                test_params,
                trial_id,
                copy.deepcopy(config),
                copy.deepcopy(sampled),
            )
            pending.append(job)
        # Block until every submitted trial has finished.
        concurrent.futures.wait(pending)

    # Collect the outcomes (re-raising any worker exception) and rank them,
    # best mean score first.
    ranked = [job.result() for job in pending]
    ranked.sort(key=lambda outcome: outcome['mean_score'], reverse=True)

    return {
        'best_params': ranked[0]['params'],
        'results': ranked,
    }
评论列表
文章目录