nccl.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:ParlAI 作者: facebookresearch 项目源码 文件源码
def all_reduce(input, output=None, op=SUM, stream=None):
    comm = communicator()
    if output is None:
        output = input
    if stream is not None:
        stream = stream.cuda_stream
    data_type = nccl_types[input.type()]
    check_error(lib.ncclAllReduce(
        ctypes.c_void_p(input.data_ptr()),
        ctypes.c_void_p(output.data_ptr()),
        ctypes.c_size_t(input.numel()),
        data_type,
        op,
        comm,
        ctypes.c_void_p(stream)))
    return output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号