def all_gather(tensor_list, tensor, group=group.WORLD):
"""Gathers tensors from the whole group in a list.
Arguments:
tensor_list (list[Tensor]): Output list. It should contain
correctly-sized tensors to be used for output of the collective.
tensor (Tensor): Tensor to be broadcast from current process.
group (optional): Group of the collective.
"""
assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode"
if _backend != dist_backend.NCCL:
return torch._C._dist_all_gather(tensor_list, tensor, group)
else:
return all_gather_multigpu([tensor_list], [tensor], group)
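The snippet above only shows the library side. Below is a minimal usage sketch of how a caller might invoke all_gather in this legacy API, assuming the process group has already been initialized (e.g. via dist.init_process_group with a backend such as 'gloo'); the tensor shape and values are purely illustrative.

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already been called on every process.
world_size = dist.get_world_size()
rank = dist.get_rank()

# Each rank contributes its own tensor; tensor_list must be pre-allocated
# with one correctly-sized tensor per rank to receive the gathered outputs.
tensor = torch.ones(2) * rank
tensor_list = [torch.zeros(2) for _ in range(world_size)]
dist.all_gather(tensor_list, tensor)
# After the call, tensor_list[i] holds rank i's tensor on every process.

Note how the non-NCCL path calls straight into the C extension, while the NCCL path wraps the single tensor and output list and delegates to all_gather_multigpu, since NCCL collectives operate on per-GPU lists.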