def create_training_model(model_config: model.ModelConfig,
                          args: argparse.Namespace,
                          context: List[mx.Context],
                          train_iter: data_io.BaseParallelSampleIter,
                          lr_scheduler_instance: lr_scheduler.LearningRateScheduler,
                          resume_training: bool,
                          training_state_dir: str) -> training.TrainingModel:
    """
    Build a TrainingModel and, if appropriate, restore its parameters from disk.

    Precedence for parameter initialization: a partial training state in
    ``training_state_dir`` wins over ``args.params``; if neither applies, the
    model keeps its freshly initialized parameters.

    :param model_config: The configuration for the model.
    :param args: Arguments as returned by argparse.
    :param context: The context(s) to run on.
    :param train_iter: The training data iterator.
    :param lr_scheduler_instance: The learning rate scheduler.
    :param resume_training: When True, the model will be loaded from disk.
    :param training_state_dir: Directory where the training state is stored.
    :return: The training model.
    """
    t_model = training.TrainingModel(
        config=model_config,
        context=context,
        train_iter=train_iter,
        bucketing=not args.no_bucketing,
        lr_scheduler=lr_scheduler_instance,
        gradient_compression_params=gradient_compression_params(args))

    # NOTE: parameter loading could live inside TrainingModel itself, for
    # consistency with how the training state is saved.
    if resume_training:
        logger.info("Found partial training in directory %s. Resuming from saved state.", training_state_dir)
        t_model.load_params_from_file(os.path.join(training_state_dir, C.TRAINING_STATE_PARAMS_NAME))
    elif args.params:
        logger.info("Training will initialize from parameters loaded from '%s'", args.params)
        t_model.load_params_from_file(args.params)

    return t_model