def init_process_group(backend, init_method='env://', **kwargs):
"""Initializes the distributed package.
Arguments:
backend (str): Name of the backend to use. Depending on build-time configuration
valid values include: ``tcp``, ``mpi`` and ``gloo``.
init_method (str, optional): URL specifying how to initialize the package.
world_size (int, optional): Number of processes participating in the job.
rank (int, optional): Rank of the current process.
group_name (str, optional): Group name. See description of init methods.
To enable ``backend == mpi``, PyTorch needs to built from source on a system that
supports MPI.
"""
    world_size = kwargs.pop('world_size', -1)
    group_name = kwargs.pop('group_name', '')
    rank = kwargs.pop('rank', -1)
    # Any keyword arguments left over after popping the known ones are invalid.
    assert len(kwargs) == 0, "got unexpected keyword arguments: %s" % ",".join(kwargs.keys())
    if not is_available():
        raise RuntimeError("PyTorch built without distributed support")

    global _initialized
    if _initialized:
        raise RuntimeError("trying to initialize torch.distributed twice!")
    # Hand off to the C extension to perform rendezvous and set up the process group.
    torch._C._dist_init_process_group(backend, init_method, world_size,
                                      group_name, rank)
    _initialized = _INITIALIZED_PG
    # Finish initializing the distributed C extension state; fail loudly otherwise.
    if not torch._C._dist_init_extension(False, reduce_op, group):
        raise RuntimeError("distributed module initialization failed")