def scatter(tensor, **kwargs):
"""Scatters a list of tensors to all processes in a group.
Each process will receive exactly one tensor and store its data in the
``tensor`` argument.
Arguments:
tensor (Tensor): Output tensor.
src (int): Source rank. Required in all processes except the one that
is sending the data.
scatter_list (list[Tensor]): List of tensors to scatter. Required only
in the process that is sending the data.
group (optional): Group of the collective.
"""
    # The legacy THD backend only implements collectives in process-group mode.
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    my_rank = get_rank()
    # If ``src`` is omitted, the calling rank is assumed to be the source.
    src = kwargs.pop('src', my_rank)
    scatter_list = kwargs.pop('scatter_list', None)
    _group = kwargs.pop('group', group.WORLD)
    if kwargs:
        raise RuntimeError("got unexpected kwargs")
    if src == my_rank:
        # The source rank must supply one tensor per process in the group.
        if scatter_list is None:
            raise RuntimeError("scatter_list is a required argument in scatter source")
        return torch._C._dist_scatter_send(scatter_list, tensor, _group)
    else:
        # Non-source ranks only receive; passing a scatter_list here is an error.
        if scatter_list:
            raise RuntimeError("non-empty scatter_list can be given only to scatter source")
        return torch._C._dist_scatter_recv(tensor, src, _group)
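

# --- Usage sketch (illustrative, not part of the original listing) ---
# A minimal example of calling scatter from both sides of the collective.
# It assumes the legacy THD backend has already been initialized elsewhere,
# e.g. via torch.distributed.init_process_group('tcp', ...), and that
# scatter/get_rank/get_world_size are in scope as in this module; the
# function name _scatter_example and the tensor shapes are hypothetical.

def _scatter_example():
    rank = get_rank()
    world_size = get_world_size()
    out = torch.zeros(4)  # every rank supplies the output tensor
    if rank == 0:
        # Only the source rank passes scatter_list: one tensor per rank.
        chunks = [torch.full((4,), float(i)) for i in range(world_size)]
        scatter(out, src=0, scatter_list=chunks)
    else:
        # Non-source ranks pass only the output tensor and the source rank.
        scatter(out, src=0)
    # After the call, rank i holds chunks[i] (a tensor filled with i) in out.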