multiprocessing_trainer.py source code

python

Project: fairseq-py    Author: facebookresearch
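```python
import torch

from fairseq import nccl  # fairseq-py's thin wrapper around the NCCL library


# Method of the multiprocessing trainer; `self` is the trainer replica running
# in one child process, which is bound to a single GPU.
def _async_init(self, rank, device_id, args, model, criterion, nccl_uid):
    """Initialize child processes."""
    self.args = args

    # set CUDA device: bind this child process to its own GPU
    torch.cuda.set_device(device_id)

    # initialize NCCL so this replica can join collective ops
    # (e.g. gradient all-reduce) with the other num_replicas processes
    nccl.initialize(self.num_replicas, nccl_uid, device_id)

    # copy model and criterion to current device
    self.model = model.cuda()
    self.criterion = criterion.cuda()

    # initialize optimizer and LR scheduler;
    # args.lr is a comma-separated string (e.g. "0.25,0.1"), parsed to floats
    self.args.lr = list(map(float, self.args.lr.split(',')))
    self.optimizer = self._build_optimizer()
    self.lr_scheduler = self._build_lr_scheduler()

    # per-replica training state
    self.loss = None
    self._max_bsz_seen = 0  # largest batch size seen so far
```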
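In this method, `args.lr` arrives as a comma-separated string (for example from a command-line flag) and is parsed into a list of floats before the optimizer and scheduler are built. The sketch below isolates that parsing step and adds one assumed way such a list could drive a per-epoch schedule: epoch `e` uses the `e`-th rate, and the last rate is held once the list runs out. The helper names `parse_lrs` and `lr_for_epoch` are hypothetical, and the schedule is an illustration, not necessarily how fairseq-py's `_build_lr_scheduler` consumes the list.

```python
from typing import List


def parse_lrs(lr_arg: str) -> List[float]:
    """Parse a comma-separated learning-rate string such as '0.25,0.1'."""
    # Same parsing as in _async_init: split on commas, convert each piece.
    # float() ignores surrounding whitespace, so '0.25, 0.1' also works.
    return list(map(float, lr_arg.split(',')))


def lr_for_epoch(lrs: List[float], epoch: int) -> float:
    """Hypothetical schedule: use lrs[epoch], then hold the last value."""
    if not lrs:
        raise ValueError("need at least one learning rate")
    if epoch < 0:
        raise ValueError("epoch must be non-negative")
    # Clamp the index so epochs past the end of the list reuse the final rate.
    return lrs[min(epoch, len(lrs) - 1)]


if __name__ == "__main__":
    lrs = parse_lrs("0.25,0.1,0.05")
    for epoch in range(5):
        print(epoch, lr_for_epoch(lrs, epoch))
```

Holding the final rate keeps the schedule defined for any number of epochs. To plug a rule like this into PyTorch's `torch.optim.lr_scheduler.LambdaLR`, note that its lambda returns a multiplicative factor applied to the initial rate, so the factor here would be `lr_for_epoch(lrs, e) / lrs[0]`.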