def _async_forward(self, rank, device_id, eval=False):
if eval:
self.model.eval()
else:
self.model.train()
self.optimizer.zero_grad()
sample_size, logging_output, oom = 0, {}, False
if self._sample is not None:
try:
# calculate loss and sample size
self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
except RuntimeError as e:
if not eval and 'out of memory' in str(e):
print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
oom = True
self.loss = None
if hasattr(torch.cuda, 'empty_cache'):
torch.cuda.empty_cache()
else:
raise e
return sample_size, logging_output, oom
multiprocessing_trainer.py 文件源码
python
阅读 64
收藏 0
点赞 0
评论 0
评论列表
文章目录