def train_model(model,
                features,
                labels,
                tile_size,
                model_id,
                nb_epoch=10,
                checkpoints=False,
                tensorboard=False):
    """Fit ``model`` on the supplied tiles, save it and return it.

    Args:
        model: Object exposing ``fit``; it is trained in place.
        features: List of triples. Each triple holds a tile together with
            information about its source image and its position in that
            source image.
        labels: List of triples with the same layout as ``features``.
        tile_size: Forwarded to ``get_matrix_form`` when the tiles are
            extracted from the triples.
        model_id: Name of the sub-directory (below ``MODELS_DIR`` and
            ``TENSORBOARD_DIR``) used for this model's artefacts.
        nb_epoch: Forwarded to ``model.fit`` as ``nb_epoch``.
        checkpoints: When truthy, a ``ModelCheckpoint`` callback writing to
            ``weights.hdf5`` inside the model directory is registered.
        tensorboard: When truthy, a ``TensorBoard`` callback logging to
            ``TENSORBOARD_DIR/<model_id>`` is registered.

    Returns:
        The same ``model`` object that was passed in, after fitting.
    """
    # Only the tiles themselves are needed for training, so strip the
    # bookkeeping data from the triples before normalising the inputs.
    inputs, targets = get_matrix_form(features, labels, tile_size)
    inputs = normalise_input(inputs)

    # Everything belonging to this model (weights, saved model) lives here.
    model_dir = os.path.join(MODELS_DIR, model_id)

    # Register the optional callbacks in a fixed order: checkpointing
    # first, TensorBoard logging second.
    active_callbacks = []
    if checkpoints:
        weights_path = os.path.join(model_dir, "weights.hdf5")
        active_callbacks.append(ModelCheckpoint(weights_path))
    if tensorboard:
        board_dir = os.path.join(TENSORBOARD_DIR, model_id)
        active_callbacks.append(TensorBoard(log_dir=board_dir))

    print("Start training.")
    model.fit(
        inputs,
        targets,
        nb_epoch=nb_epoch,
        callbacks=active_callbacks,
        validation_split=0.1,
    )

    save_model(model, model_dir)
    return model
# Comment list
# Table of contents