model.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:WaterNet 作者: treigerm 项目源码 文件源码
def train_model(model,
                features,
                labels,
                tile_size,
                model_id,
                nb_epoch=10,
                checkpoints=False,
                tensorboard=False):
    """Fit ``model`` on the given tiles, save it, and return it.

    Args:
        model: Object exposing ``fit``; it is trained in place.
        features: Feature tiles. Per the original author's note these
            arrive as a list of triples (tile, source image, position in
            the source); only the tiles are used for training.
        labels: Label tiles, in the same triple form as ``features``.
        tile_size: Forwarded to ``get_matrix_form``.
        model_id: Name of the sub-directory of ``MODELS_DIR`` (and of
            ``TENSORBOARD_DIR``) used for this model.
        nb_epoch: Forwarded to ``model.fit`` as ``nb_epoch``.
        checkpoints: If true, attach a ``ModelCheckpoint`` callback that
            targets ``weights.hdf5`` inside the model directory.
        tensorboard: If true, attach a ``TensorBoard`` callback logging
            to ``TENSORBOARD_DIR/<model_id>``.

    Returns:
        The same ``model`` object that was passed in.
    """
    # Reduce the (tile, source, position) triples to plain matrices and
    # normalise the inputs before fitting.
    X, y = get_matrix_form(features, labels, tile_size)
    X = normalise_input(X)

    # Everything belonging to this model lives in its own directory.
    model_dir = os.path.join(MODELS_DIR, model_id)

    # Collect only the callbacks that were requested, checkpointing first.
    callbacks = []
    if checkpoints:
        callbacks.append(
            ModelCheckpoint(os.path.join(model_dir, "weights.hdf5")))
    if tensorboard:
        callbacks.append(
            TensorBoard(log_dir=os.path.join(TENSORBOARD_DIR, model_id)))

    print("Start training.")
    model.fit(X,
              y,
              nb_epoch=nb_epoch,
              callbacks=callbacks,
              validation_split=0.1)

    save_model(model, model_dir)
    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号