test_communicator.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:chainermn 作者: chainer 项目源码 文件源码
def create_communicator(param, use_gpu):
    if not param.multi_node:
        ranks = _communication_utility.init_ranks(mpi_comm)
        inter_size = ranks[4]
        if inter_size > 1:
            pytest.skip('This test is for single node only')

    if use_gpu and not param.nccl1 and nccl.get_version() < 2000:
        pytest.skip('This test requires NCCL version >= 2.0')

    communicator = param.communicator_class(mpi_comm)

    if hasattr(communicator, 'intra_rank'):
        chainer.cuda.get_device(communicator.intra_rank).use()

    return communicator
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号