def all_reduce(tensor, op=reduce_op.SUM, group=group.WORLD):
    """Reduces the tensor data across all machines in such a way that all get
    the final result.

    After the call ``tensor`` is going to be bitwise identical in all processes.

    Arguments:
        tensor (Tensor): Input and output of the collective. The function
            operates in-place.
        op (optional): One of the values from ``torch.distributed.reduce_op``
            enum. Specifies an operation used for element-wise reductions.
        group (optional): Group of the collective.
    """
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_all_reduce(tensor, op, group)
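
To see the in-place semantics in practice, here is a minimal usage sketch. It is an illustration built on assumptions, not part of the original source: the "gloo" backend, the loopback TCP address, and the two-process spawn launch are all chosen for the example, and it uses the current ``dist.ReduceOp.SUM`` spelling (the snippet's ``reduce_op.SUM`` is the legacy name for the same enum).

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    # Illustrative setup: backend and address are assumptions for this sketch.
    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:23456",
                            rank=rank, world_size=world_size)
    # Each rank starts with a different tensor:
    # rank 0 holds [1, 1, 1, 1], rank 1 holds [2, 2, 2, 2].
    t = torch.ones(4) * (rank + 1)
    # In-place element-wise SUM; afterwards every rank holds [3, 3, 3, 3].
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    print(f"rank {rank}: {t.tolist()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)

Because the reduction mutates ``tensor`` in place, no new tensor is allocated: each process contributes its local values and, once the call returns, every process reads the same reduced result from the same buffer.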