train.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:sockeye 作者: awslabs 项目源码 文件源码
def determine_decode_and_evaluate_context(args: argparse.Namespace,
                                          exit_stack: ExitStack,
                                          train_context: List[mx.Context]) -> Tuple[int, Optional[mx.Context]]:
    """
    Determine the number of sentences to decode and the context we should run on (CPU or GPU).

    If BLEU is the optimized metric and ``args.decode_and_evaluate`` is 0, decoding is switched on
    anyway by setting the count to -1 (presumably "all validation sentences" -- confirm against the
    consumer of this value).

    The decode context is chosen in this order:
      1. CPU, if ``args.use_cpu`` or ``args.decode_and_evaluate_use_cpu`` is set;
      2. the GPU given by ``args.decode_and_evaluate_device_id``, if set;
      3. otherwise the last training context.

    :param args: Arguments as returned by argparse.
    :param exit_stack: An ExitStack from contextlib. When device locking is enabled, the context
                       manager returned by ``utils.acquire_gpus`` is entered on this stack, so its
                       cleanup runs when the caller closes the stack rather than when this returns.
    :param train_context: Context(s) used for training; the last entry is the default decode context.
    :return: A tuple ``(num_to_decode, context)``. When decoding is disabled this is ``(0, None)``;
             otherwise ``context`` is a single ``mx.Context`` (not a list) to decode on.
    """
    num_to_decode = args.decode_and_evaluate
    if args.optimized_metric == C.BLEU and num_to_decode == 0:
        logger.info("You chose BLEU as the optimized metric, will turn on BLEU monitoring during training. "
                    "To control how many validation sentences are used for calculating bleu use "
                    "the --decode-and-evaluate argument.")
        num_to_decode = -1

    if num_to_decode == 0:
        # Decoding during training is disabled; no context is needed.
        return 0, None

    if args.use_cpu or args.decode_and_evaluate_use_cpu:
        context = mx.cpu()
    elif args.decode_and_evaluate_device_id is not None:
        # decode device is defined from the commandline
        num_gpus = utils.get_num_gpus()
        check_condition(num_gpus >= 1,
                        "No GPUs found, consider running on the CPU with --use-cpu "
                        "(note: check depends on nvidia-smi and this could also mean that the nvidia-smi "
                        "binary isn't on the path).")

        requested_device_ids = [args.decode_and_evaluate_device_id]
        # Keep the raw device-id list separate from the final mx.Context so each variable has one type.
        if args.disable_device_locking:
            device_ids = utils.expand_requested_device_ids(requested_device_ids)
        else:
            # Entered on the caller's exit_stack so the acquisition outlives this function call.
            device_ids = exit_stack.enter_context(utils.acquire_gpus(requested_device_ids,
                                                                     lock_dir=args.lock_dir))
        context = mx.gpu(device_ids[0])
    else:
        # default decode context is the last training device
        context = train_context[-1]

    logger.info("Decode and Evaluate Device(s): %s", context)
    return num_to_decode, context
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号