import os


def train_ensemble(trainset, valset, path_data, path_session, hyper_param):
    """Train an ensemble of models, one per combination of hyper parameters.

    Args:
        trainset, valset: training and validation sets from `split_train_val()`
        path_data: /path/to/train_detections.hdf5
        path_session: string specifying the session's output path
        hyper_param: dictionary of candidate values, with entries as follows -
            * epochs: number of epochs
            * batch_sz: list of batch sizes to try in training
            * batch_norm: do batch normalization? (list of booleans)
            * optimizers: a list of keras.optimizers beasts
            * lr_scheduler_param: list of argument tuples for
              `make_lr_scheduler()`, each yielding a
              keras.callbacks.LearningRateScheduler
            * dropout_rate: list of dropout rates to try
            * pool_type: list of pooling types to try

    Returns:
        list of `train()` results, sorted by validation loss (ascending).
    """
    models = []
    for i, batch_sz in enumerate(hyper_param["batch_sz"]):
        for j, optimizer in enumerate(hyper_param["optimizers"]):
            for k, lr_param in enumerate(hyper_param["lr_scheduler_param"]):
                for l, dropout_rate in enumerate(hyper_param["dropout_rate"]):
                    for m, batch_norm in enumerate(hyper_param["batch_norm"]):
                        for n, pool_type in enumerate(hyper_param["pool_type"]):
                            # prepare this task's hyper param
                            hyper_param_ = {
                                "epochs": hyper_param["epochs"],
                                "batch_sz": batch_sz,
                                "optimizer": optimizer,
                                "lr_schedule": make_lr_scheduler(*lr_param),
                                "dropout_rate": dropout_rate,
                                "batch_norm": batch_norm,
                                "pool_type": pool_type
                            }
                            # task's path, tagged with the indices of this
                            # hyper parameter combination
                            session_id_ = "{}.{}_{}_{}_{}_{}_{}".format(
                                os.path.basename(path_session),
                                i, j, k, l, m, n)
                            path_session_ = os.path.join(path_session,
                                                         session_id_)
                            if not os.path.exists(path_session_):
                                os.mkdir(path_session_)
                            # train one ensemble member
                            models.append(train(
                                trainset,
                                valset,
                                path_data,
                                path_session_,
                                hyper_param_))
    # sort by validation loss; list.sort() sorts in place and returns None,
    # so sort first and then return the list
    models.sort(key=lambda t: t[1])
    return models
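
The loop forwards each `lr_param` tuple to `make_lr_scheduler()`, which is defined elsewhere. As a hypothetical sketch only (the listing guarantees nothing beyond a keras.callbacks.LearningRateScheduler coming back), such a factory could implement a step decay:

import math
from tensorflow import keras

def make_lr_scheduler(lr0, decay, step):
    # hypothetical step-decay factory, NOT the actual helper from this code:
    # scale the base rate lr0 by `decay` once every `step` epochs
    def schedule(epoch):
        return lr0 * math.pow(decay, math.floor(epoch / step))
    return keras.callbacks.LearningRateScheduler(schedule)

And a minimal usage sketch; the candidate values and the Adam/SGD choices are illustrative assumptions, not values taken from this session:

from tensorflow import keras

# trainset, valset come from split_train_val(), as noted in the docstring
hyper_param = {
    "epochs": 50,
    "batch_sz": [32, 64],
    "optimizers": [keras.optimizers.Adam(),
                   keras.optimizers.SGD(momentum=0.9)],
    "lr_scheduler_param": [(0.001, 0.5, 10)],  # forwarded to make_lr_scheduler()
    "dropout_rate": [0.25, 0.5],
    "batch_norm": [True, False],
    "pool_type": ["max", "avg"],
}
models = train_ensemble(trainset, valset,
                        "/path/to/train_detections.hdf5",
                        "/path/to/session",
                        hyper_param)
best = models[0]  # smallest validation loss after the sort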