def gather(tensor, **kwargs):
    """Gathers a list of tensors in a single process.

    Arguments:
        tensor (Tensor): Input tensor.
        dst (int): Destination rank. Required in all processes except the
            one that is receiving the data.
        gather_list (list[Tensor]): List of appropriately-sized tensors to
            use for received data. Required only in the receiving process.
        group (optional): Group of the collective.
    """
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    my_rank = get_rank()
    # dst defaults to the caller's own rank, so the destination may omit it.
    dst = kwargs.pop('dst', my_rank)
    gather_list = kwargs.pop('gather_list', None)
    _group = kwargs.pop('group', group.WORLD)
    if kwargs:
        raise RuntimeError("got unexpected kwargs")
    if dst == my_rank:
        # Receiving side: a pre-allocated output tensor per rank is required.
        if gather_list is None:
            raise RuntimeError("gather_list is a required argument in gather destination")
        return torch._C._dist_gather_recv(gather_list, tensor, _group)
    else:
        # Sending side: only the input tensor and the destination rank are needed.
        if gather_list:
            raise RuntimeError("non-empty gather_list can be given only to gather destination")
        return torch._C._dist_gather_send(tensor, dst, _group)
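To see how the two branches cooperate, here is a minimal usage sketch, not part of the source above. It assumes the legacy process-group API (initialized via init_process_group with rank and world size supplied through the environment by env://); the backend choice and tensor shapes are illustrative assumptions.

# Usage sketch (assumptions: legacy torch.distributed API, env:// init,
# RANK and WORLD_SIZE set in the environment by the launcher).
import torch
import torch.distributed as dist

dist.init_process_group(backend='tcp', init_method='env://')
rank = dist.get_rank()
world_size = dist.get_world_size()

tensor = torch.zeros(2) + rank  # each rank contributes its own values
if rank == 0:
    # Destination rank: pre-allocate one buffer per rank and omit dst
    # (it defaults to my_rank, as seen in the source above).
    gather_list = [torch.zeros(2) for _ in range(world_size)]
    dist.gather(tensor, gather_list=gather_list)
    # gather_list[i] now holds rank i's tensor.
else:
    # Sender ranks: name the destination and pass no gather_list.
    dist.gather(tensor, dst=0)

Note the asymmetry the source enforces: only the destination rank passes gather_list, and any other rank passing a non-empty one triggers the RuntimeError in the sending branch.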