train.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目: sockeye 作者: awslabs 项目源码 文件源码
def create_training_model(model_config: model.ModelConfig,
                          args: argparse.Namespace,
                          context: List[mx.Context],
                          train_iter: data_io.BaseParallelSampleIter,
                          lr_scheduler_instance: lr_scheduler.LearningRateScheduler,
                          resume_training: bool,
                          training_state_dir: str) -> training.TrainingModel:
    """
    Build a ``training.TrainingModel`` and, if requested, load its parameters from disk.

    Parameter source, in order of precedence:

    1. ``resume_training`` is True: the file ``C.TRAINING_STATE_PARAMS_NAME`` inside
       ``training_state_dir``.
    2. ``args.params`` is truthy: the file at that path.
    3. Neither: no parameters are loaded here, and the model is returned as constructed.

    :param model_config: Model configuration passed to the training model as ``config``.
    :param args: Parsed command-line arguments. ``args.no_bucketing`` and ``args.params``
                 are read here; the whole namespace is also handed to
                 ``gradient_compression_params``.
    :param context: Context(s) the model runs on.
    :param train_iter: Iterator over the training data.
    :param lr_scheduler_instance: Learning-rate scheduler given to the training model.
    :param resume_training: If True, restore parameters from the saved training state.
    :param training_state_dir: Directory that holds the saved training state.
    :return: The constructed (and possibly parameter-loaded) training model.
    """
    # Evaluate these in the same order as the original keyword-argument list:
    # bucketing flag first, then the gradient compression settings.
    use_bucketing = not args.no_bucketing
    compression = gradient_compression_params(args)

    trainer = training.TrainingModel(config=model_config,
                                     context=context,
                                     train_iter=train_iter,
                                     bucketing=use_bucketing,
                                     lr_scheduler=lr_scheduler_instance,
                                     gradient_compression_params=compression)

    # NOTE: parameter loading could arguably move into TrainingModule so that it
    # mirrors how the training state is saved.
    params_file = None
    if resume_training:
        logger.info("Found partial training in directory %s. Resuming from saved state.", training_state_dir)
        params_file = os.path.join(training_state_dir, C.TRAINING_STATE_PARAMS_NAME)
    elif args.params:
        logger.info("Training will initialize from parameters loaded from '%s'", args.params)
        params_file = args.params

    # Both branches above yield a non-None path, so this matches "load iff a source was chosen".
    if params_file is not None:
        trainer.load_params_from_file(params_file)

    return trainer
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号