def __init__(self, args, model, criterion, device_ids=None,
             multiprocessing_method='spawn'):
    if device_ids is None:
        device_ids = tuple(range(torch.cuda.device_count()))
    super().__init__(device_ids, multiprocessing_method)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')

    # Move the model's parameters into shared memory so every worker
    # process starts from the same weights, and create one NCCL unique
    # id so all replicas join the same communicator.
    model = model.share_memory()
    nccl_uid = nccl.get_unique_id()
    self.criterion = criterion

    # Dispatch _async_init to every replica and wait until all of them
    # have finished setting up.
    Future.gen_list([
        self.call_async(rank, '_async_init', args=args, model=model,
                        criterion=criterion, nccl_uid=nccl_uid)
        for rank in range(self.num_replicas)
    ])

    self._grads_initialized = False
Source file: multiprocessing_trainer.py
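For context, this constructor shares the model's parameters across worker processes, creates a single NCCL unique id, and asynchronously runs _async_init on every replica before returning. Below is a minimal instantiation sketch, not taken from the source file: it assumes the enclosing class is MultiprocessingTrainer, and args, model, and criterion are placeholders built elsewhere.

# Hypothetical usage sketch; relies only on the signature shown above.
trainer = MultiprocessingTrainer(args, model, criterion)                      # one replica per visible GPU
trainer = MultiprocessingTrainer(args, model, criterion, device_ids=(0, 1))   # or restrict to specific GPUs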