multiprocessing_trainer.py source code

python

Project: fairseq-py    Author: facebookresearch
def _async_forward(self, rank, device_id, eval=False):
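        """Forward pass for one worker process; recovers from CUDA OOM.

        Returns (sample_size, logging_output, oom). This is a method of
        fairseq-py's MultiprocessingTrainer class, so `import torch` and
        the `self.*` attributes are provided by the surrounding module.
        """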
        if eval:
            self.model.eval()
        else:
            self.model.train()
            self.optimizer.zero_grad()

        sample_size, logging_output, oom = 0, {}, False
        if self._sample is not None:
            try:
                # calculate loss and sample size
                self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
            except RuntimeError as e:
                if not eval and 'out of memory' in str(e):
                    print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
                    oom = True
                    self.loss = None
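                    # release cached, unused GPU memory so the next batch
                    # has a chance to fit; guarded with hasattr() because
                    # very old PyTorch versions lack torch.cuda.empty_cache()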
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                else:
                    raise e

        return sample_size, logging_output, oom
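
The noteworthy part of this method is the recovery path: on a CUDA out-of-memory error during training, the worker skips the batch, drops the loss, and frees cached GPU memory instead of crashing. The sketch below shows the same pattern in isolation; `forward_with_oom_skip`, `model`, `criterion`, and `sample` are illustrative placeholders rather than fairseq-py API, and the criterion is assumed to return (loss, sample_size, logging_output) as in the snippet above.

import torch

def forward_with_oom_skip(model, criterion, sample, device_id=0, train=True):
    # Hypothetical standalone version of the OOM-skip pattern above.
    # `criterion(model, sample)` is assumed to return
    # (loss, sample_size, logging_output), mirroring the snippet.
    if train:
        model.train()
    else:
        model.eval()
    try:
        loss, sample_size, logging_output = criterion(model, sample)
        return loss, sample_size, logging_output, False   # oom=False
    except RuntimeError as e:
        if not train or 'out of memory' not in str(e):
            raise   # a real error, or OOM during eval: propagate it
        print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
        torch.cuda.empty_cache()   # release cached, unused GPU memory
        return None, 0, {}, True   # oom=True: caller should skip this batch

A caller can then treat an oom result as a no-op batch: skip the backward pass and exclude the batch's sample_size and logging_output from aggregation, so a single oversized batch does not bring down the whole worker.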