train.py source code

python

Project: sockeye  Author: awslabs
def determine_context(args: argparse.Namespace, exit_stack: ExitStack) -> List[mx.Context]:
    """
    Determine the context we should run on (CPU or GPU).

    :param args: Arguments as returned by argparse.
    :param exit_stack: An ExitStack from contextlib.
    :return: A list with the context(s) to run on.
    """
    if args.use_cpu:
        logger.info("Training Device: CPU")
        context = [mx.cpu()]
    else:
        # GPU training: fail early if no GPU is visible.
        num_gpus = utils.get_num_gpus()
        check_condition(num_gpus >= 1,
                        "No GPUs found, consider running on the CPU with --use-cpu "
                        "(note: check depends on nvidia-smi and this could also mean that the nvidia-smi "
                        "binary isn't on the path).")
        if args.disable_device_locking:
            # Resolve the requested device ids without acquiring locks.
            context = utils.expand_requested_device_ids(args.device_ids)
        else:
            # Acquire the GPUs through a context manager registered on exit_stack,
            # so they are released when the caller's ExitStack closes.
            context = exit_stack.enter_context(utils.acquire_gpus(args.device_ids, lock_dir=args.lock_dir))
        if args.batch_type == C.BATCH_TYPE_SENTENCE:
            # Sentence-based batches must divide evenly across the devices.
            check_condition(args.batch_size % len(context) == 0, "When using multiple devices the batch size must be "
                                                                 "divisible by the number of devices. Choose a batch "
                                                                 "size that is a multiple of %d." % len(context))
        logger.info("Training Device(s): GPU %s", context)
        # Up to here `context` holds integer GPU ids; convert them to MXNet contexts.
        context = [mx.gpu(gpu_id) for gpu_id in context]
    return context
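
The function above passes an ExitStack in from the caller so that whatever acquire_gpus holds stays held for the whole training run and is released only when the caller's stack closes. The sketch below illustrates that pattern together with the batch-size divisibility check, without depending on MXNet or sockeye. The names acquire_devices, pick_devices and released are hypothetical stand-ins invented for this illustration, not sockeye APIs, and plain ValueError takes the place of check_condition.

```python
from contextlib import ExitStack, contextmanager
from typing import Iterator, List

# Records which (hypothetical) device locks have been released so far.
released: List[int] = []


@contextmanager
def acquire_devices(device_ids: List[int]) -> Iterator[List[int]]:
    """Hypothetical stand-in for a lock-acquiring context manager."""
    try:
        yield list(device_ids)
    finally:
        # Runs when the owning ExitStack closes, even after an error.
        released.extend(device_ids)


def pick_devices(device_ids: List[int], batch_size: int, exit_stack: ExitStack) -> List[int]:
    """Acquire devices on the caller's stack, then validate the batch size."""
    devices = exit_stack.enter_context(acquire_devices(device_ids))
    if batch_size % len(devices) != 0:
        raise ValueError("Batch size must be a multiple of %d." % len(devices))
    return devices


with ExitStack() as stack:
    devices = pick_devices([0, 1], batch_size=64, exit_stack=stack)
    # Still inside the stack: nothing has been released yet.
    held_during_training = list(released)
```

Because the helper only registers the context manager and never closes the stack itself, the lifetime of the locks is decided by the caller. If the divisibility check fails, the exception propagates out of the with block and the stack still releases the devices.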