def get_world_size():
    """Return the number of processes in the distributed group.

    Returns:
        The process count reported by
        ``torch._C._dist_get_num_processes()``.

    Raises:
        AssertionError: if ``torch.distributed._initialized`` is falsy,
            i.e. the distributed package has not been initialized yet
            (presumably set by the package's init routine -- confirm).
    """
    # Explicit check instead of a bare ``assert``: ``assert`` statements are
    # removed under ``python -O``, which would let the C call run without the
    # guard. AssertionError is kept so callers catching it are unaffected.
    if not torch.distributed._initialized:
        raise AssertionError("torch.distributed has not been initialized")
    return torch._C._dist_get_num_processes()