def all_reduce(input, output=None, op=SUM, stream=None):
    """Run ``ncclAllReduce`` on ``input`` and return the receive tensor.

    Args:
        input: Tensor-like object providing the send buffer. Its ``type()``
            selects the NCCL data type through ``nccl_types`` and its
            ``numel()`` is the element count passed to NCCL.
            NOTE(review): the raw ``data_ptr()`` is passed, so this
            presumably expects contiguous device memory -- confirm.
        output: Object providing the receive buffer. When ``None`` the
            reduction is performed in place and ``input`` is reused.
        op: Reduction operator forwarded unchanged to NCCL.
        stream: Optional stream object; its ``cuda_stream`` attribute is
            passed to NCCL. ``None`` results in a null stream pointer.

    Returns:
        The object that was used as the receive buffer (``output``, or
        ``input`` when ``output`` was ``None``).

    Raises:
        KeyError: If ``input.type()`` has no entry in ``nccl_types``.
        TypeError: If ``output`` has a different tensor type than ``input``.
        ValueError: If ``output`` has fewer elements than ``input``.
    """
    comm = communicator()

    if output is None:
        # In-place reduction: send and receive buffers are the same tensor.
        output = input
    else:
        # NCCL writes ``input.numel()`` elements of ``input``'s data type to
        # the receive pointer, so a mismatched ``output`` would be silently
        # corrupted or overrun. Fail loudly before handing pointers over.
        if output.type() != input.type():
            raise TypeError(
                "output type {} does not match input type {}".format(
                    output.type(), input.type()))
        if output.numel() < input.numel():
            raise ValueError(
                "output has {} elements but input has {}".format(
                    output.numel(), input.numel()))

    if stream is not None:
        stream = stream.cuda_stream

    data_type = nccl_types[input.type()]
    # ctypes.c_void_p(None) is a NULL pointer, which covers stream=None.
    check_error(lib.ncclAllReduce(
        ctypes.c_void_p(input.data_ptr()),
        ctypes.c_void_p(output.data_ptr()),
        ctypes.c_size_t(input.numel()),
        data_type,
        op,
        comm,
        ctypes.c_void_p(stream)))
    return output
评论列表
文章目录