multiprocessing_trainer.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:fairseq-py 作者: facebookresearch 项目源码 文件源码
def __init__(self, args, model, criterion, device_ids=None,
                 multiprocessing_method='spawn'):
        if device_ids is None:
            device_ids = tuple(range(torch.cuda.device_count()))
        super().__init__(device_ids, multiprocessing_method)

        if not torch.cuda.is_available():
            raise NotImplementedError('Training on CPU is not supported')
        model = model.share_memory()
        nccl_uid = nccl.get_unique_id()
        self.criterion = criterion

        Future.gen_list([
            self.call_async(rank, '_async_init', args=args, model=model,
                            criterion=criterion, nccl_uid=nccl_uid)
            for rank in range(self.num_replicas)
        ])

        self._grads_initialized = False
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号