import chainer
import mpi4py.MPI
import pytest
from chainermn import nccl
from chainermn.communicators import _communication_utility

mpi_comm = mpi4py.MPI.COMM_WORLD


def create_communicator(param, use_gpu):
    if not param.multi_node:
        # init_ranks() index 4 is inter_size: the number of nodes.
        ranks = _communication_utility.init_ranks(mpi_comm)
        inter_size = ranks[4]
        if inter_size > 1:
            pytest.skip('This test is for single node only')
    if use_gpu and not param.nccl1 and nccl.get_version() < 2000:
        pytest.skip('This test requires NCCL version >= 2.0')
    communicator = param.communicator_class(mpi_comm)
    # GPU-aware communicators expose intra_rank; bind this
    # process to its GPU within the node.
    if hasattr(communicator, 'intra_rank'):
        chainer.cuda.get_device(communicator.intra_rank).use()
    return communicator
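
For context, here is a minimal sketch of how a test might drive this helper. The Param holder and the specific test body are illustrative assumptions, not part of the snippet above; only the attribute names it reads (multi_node, nccl1, communicator_class) are dictated by the helper, and NaiveCommunicator is ChainerMN's basic CPU-capable MPI communicator.

from chainermn.communicators.naive_communicator import NaiveCommunicator


# Hypothetical parameter holder with the attributes the helper reads.
class Param(object):
    def __init__(self, communicator_class, multi_node=False, nccl1=False):
        self.communicator_class = communicator_class
        self.multi_node = multi_node
        self.nccl1 = nccl1


def test_communicator_size():
    param = Param(NaiveCommunicator)
    communicator = create_communicator(param, use_gpu=False)
    # The communicator should span the whole MPI world.
    assert communicator.size == mpi_comm.size

Run under MPI, e.g. mpiexec -n 2 python -m pytest, so that mpi_comm actually contains multiple processes; the single-node skip logic only triggers when ranks are spread across machines.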