import chainer
import chainermn
import pytest


def create_communicator(gpu):
    if gpu:
        # NCCL-based hierarchical communicator; pin this process to the
        # GPU matching its intra-node rank.
        communicator = chainermn.create_communicator('hierarchical')
        chainer.cuda.get_device(communicator.intra_rank).use()
    else:
        # Pure-MPI communicator for CPU-only runs.
        communicator = chainermn.create_communicator('naive')

    if communicator.size < 2:
        pytest.skip("This test is for multinode only")

    # Neighboring ranks in a ring topology (wrapping at the ends).
    rank_next = (communicator.rank + 1) % communicator.size
    rank_prev = (communicator.rank - 1) % communicator.size
    return communicator, rank_next, rank_prev
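
A minimal usage sketch, not from the source: a hypothetical ring-style point-to-point test built on this helper, run under e.g. `mpiexec -n 2 python -m pytest`. It assumes the communicator exposes blocking `send(data, dest, tag)` / `recv(source, tag)` methods; ranks alternate send/recv order so two blocking peers cannot deadlock.

import numpy as np

def test_send_recv_ring(gpu=False):
    # Hypothetical test body: each process passes an array to its next
    # neighbor and receives one from its previous neighbor.
    communicator, rank_next, rank_prev = create_communicator(gpu)
    data = np.arange(5, dtype=np.float32) + communicator.rank
    if communicator.rank % 2 == 0:
        communicator.send(data, dest=rank_next, tag=0)
        received = communicator.recv(source=rank_prev, tag=0)
    else:
        received = communicator.recv(source=rank_prev, tag=0)
        communicator.send(data, dest=rank_next, tag=0)
    # The array from rank_prev should carry that rank's offset.
    expected = np.arange(5, dtype=np.float32) + rank_prev
    np.testing.assert_allclose(received, expected)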