__init__.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:pytorch 作者: pytorch 项目源码 文件源码
def all_gather_multigpu(output_tensor_lists,
                        input_tensor_list,
                        group=group.WORLD):
    """Gathers tensors from the whole group into per-GPU output lists.

    Each tensor in ``input_tensor_list`` should reside on a separate GPU.

    Only the nccl backend is currently supported; all tensors should be
    GPU tensors.

    Arguments:
        output_tensor_lists (List[List[Tensor]]): Output lists. It should
            contain correctly-sized tensors on each GPU to be used for output
            of the collective. Each inner list is flattened into a single
            buffer for the collective, and the gathered values are then
            copied back into the original tensors in place.
        input_tensor_list (List[Tensor]): List of tensors (on different GPUs)
            contributed by the current process to the all-gather.
        group (optional): Group of the collective.

    Returns:
        The value returned by ``torch._C._dist_all_gather_multigpu``.

    Raises:
        AssertionError: if ``torch.distributed`` was not initialized in
            process-group mode. This check is an ``assert`` and is therefore
            skipped when Python runs with ``-O``.
    """
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"

    warnings.warn("""
    ================================================================================
                                        WARNING
    ================================================================================
    all_gather_multigpu is still experimental. The API will change without
    notice and we're can't guarantee full correctness and expected performance yet.
    We'll announce it once it's ready.
    """)

    # One flat buffer per output list; these buffers (not the caller's
    # tensors) are what the collective receives.
    flatten_tensor_list = [_flatten_dense_tensors(output_tensor_list)
                           for output_tensor_list in output_tensor_lists]

    ret = torch._C._dist_all_gather_multigpu(flatten_tensor_list,
                                             input_tensor_list,
                                             group)

    # Split each flat buffer back into pieces shaped like the caller's
    # tensors and copy them in place, so results land in the user's outputs.
    for output_tensor_list, flatten_tensor in zip(output_tensor_lists,
                                                  flatten_tensor_list):
        unflattened = _unflatten_dense_tensors(flatten_tensor,
                                               output_tensor_list)
        for tensor, value in zip(output_tensor_list, unflattened):
            tensor.copy_(value)

    return ret
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号