classifier.py source code

Project: kaggle_dsb  Author: syagev
import os

# train() and make_lr_scheduler() are defined or imported elsewhere in
# classifier.py (not shown in this excerpt).


def train_ensemble(trainset, valset, path_data, path_session, hyper_param):
    """Train one model per combination of hyper-parameters (a grid search).

    Args:
        trainset, valset: training and validation sets from `split_train_val()`
        path_data: /path/to/train_detections.hdf5
        path_session: string specifying the session's output path
        hyper_param: dictionary with the entries below; every entry except
            "epochs" is a list of candidate values to sweep over:
                * epochs: number of epochs
                * batch_sz: batch sizes used in training
                * optimizers: keras.optimizers instances
                * lr_scheduler_param: argument tuples for make_lr_scheduler()
                * dropout_rate: dropout rates
                * batch_norm: whether to use batch normalization (bools)
                * pool_type: pooling types

    Returns:
        list of the results returned by train(), sorted by validation loss
        (the second element of each result).
    """

    models = []
    for i, batch_sz in enumerate(hyper_param["batch_sz"]):
        for j, optimizer in enumerate(hyper_param["optimizers"]):
            for k, lr_param in enumerate(hyper_param["lr_scheduler_param"]):
                for l, dropout_rate in enumerate(hyper_param["dropout_rate"]):
                    for m, batch_norm in enumerate(hyper_param["batch_norm"]):
                        for n, pool_type in enumerate(hyper_param["pool_type"]):

                            # prepare this task's hyper-parameters
                            hyper_param_ = {
                                "epochs": hyper_param["epochs"],
                                "batch_sz": batch_sz,
                                "optimizer": optimizer,
                                "lr_schedule": make_lr_scheduler(*lr_param),
                                "dropout_rate": dropout_rate,
                                "batch_norm": batch_norm,
                                "pool_type": pool_type
                                }

                            # this task's output directory, named after the
                            # session plus the loop indices, e.g. "run.0_1_0_0_1_0"
                            session_id_ = "{}.{}_{}_{}_{}_{}_{}".format(
                                os.path.basename(path_session),
                                i, j, k, l, m, n)
                            path_session_ = os.path.join(path_session,
                                                         session_id_)
                            if not os.path.exists(path_session_):
                                os.mkdir(path_session_)

                            # train; each result's second element is the
                            # validation loss (used for sorting below)
                            models.append(train(
                                trainset,
                                valset,
                                path_data,
                                path_session_,
                                hyper_param_))

    # sort by validation loss; list.sort() sorts in place and returns None,
    # so sort first and then return the list itself
    models.sort(key=lambda result: result[1])
    return models
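The six nested loops in `train_ensemble()` walk the Cartesian product of the swept hyper-parameter lists and name each task's directory after the loop indices. Below is a minimal sketch, not the project's own code, of the same enumeration written with `itertools.product`, plus the ranking step done with `sorted()`. The names `SWEPT_KEYS`, `iter_grid`, `session_dir` and `rank_by_val_loss` are hypothetical, the values in `grid` are made up for illustration, and the sketch assumes, as the original comment suggests, that the second element of each training result is the validation loss.

```python
import itertools
import os

# Hypothetical: the swept keys, in the same order as the nested loops in
# train_ensemble() (the last key varies fastest).
SWEPT_KEYS = ("batch_sz", "optimizers", "lr_scheduler_param",
              "dropout_rate", "batch_norm", "pool_type")


def iter_grid(hyper_param):
    """Yield (indices, values) for every point of the hyper-parameter grid."""
    enumerated = [list(enumerate(hyper_param[key])) for key in SWEPT_KEYS]
    # product() over (index, value) pairs keeps each value's index, which is
    # what the i, j, k, l, m, n loop variables provide in the original.
    for combo in itertools.product(*enumerated):
        indices = tuple(idx for idx, _ in combo)
        values = tuple(val for _, val in combo)
        yield indices, values


def session_dir(path_session, indices):
    """Build a per-task path such as 'out/run/run.0_0_0_1_0_0'."""
    session_id = "{}.{}".format(os.path.basename(path_session),
                                "_".join(str(i) for i in indices))
    return os.path.join(path_session, session_id)


def rank_by_val_loss(results):
    """Return results ordered by their second element (validation loss).

    sorted() returns a new list, whereas list.sort() returns None.
    """
    return sorted(results, key=lambda result: result[1])


# Made-up example grid: 2 batch sizes x 2 dropout rates = 4 tasks.
grid = {
    "epochs": 10,
    "batch_sz": [32, 64],
    "optimizers": ["adam"],
    "lr_scheduler_param": [(0.01, 0.5)],
    "dropout_rate": [0.3, 0.5],
    "batch_norm": [True],
    "pool_type": ["max"],
}
tasks = list(iter_grid(grid))
print(len(tasks))                           # 4
print(session_dir("out/run", tasks[1][0]))  # out/run/run.0_0_0_1_0_0 on POSIX
```

Flattening the loops this way keeps the original iteration order and the same `i_j_k_l_m_n` directory suffixes, so in principle it could replace the nested loops without changing the on-disk layout.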