def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=None):
    """Restore model (and, when compatible, optimizer) state from a checkpoint file.

    Args:
        filename: Path of the checkpoint to read with ``torch.load``.
        model: Object whose parameters are overwritten via
            ``model.load_state_dict(state['model'])``.
        criterion: Current criterion; only its class name is compared with the
            ``criterion_name`` of the checkpoint's last optimizer-history entry.
        optimizer: Its state is restored from ``state['last_optimizer_state']``
            when the criterion class names match.
        lr_scheduler: Its ``best`` attribute is set from the last history
            entry's ``best_loss`` when the criterion class names match.
        cuda_device: If not ``None``, every storage is restored with the
            ``'cuda:<cuda_device>'`` location tag; otherwise ``torch.load`` is
            called without a ``map_location``.

    Returns:
        ``(state['extra_state'], state['optimizer_history'])`` from the
        checkpoint, or ``(None, [])`` if ``filename`` does not exist.

    Raises:
        KeyError: if the (upgraded) checkpoint lacks one of the keys read here.
    """
    if not os.path.exists(filename):
        return None, []

    if cuda_device is None:
        state = torch.load(filename)
    else:
        # The saved location tag is ignored: every storage is restored with the
        # 'cuda:<cuda_device>' tag instead.
        state = torch.load(
            filename,
            map_location=lambda storage, location: default_restore_location(
                storage, 'cuda:{}'.format(cuda_device)
            ),
        )
    # NOTE(review): presumably migrates older checkpoint layouts so the keys
    # read below exist -- confirm against _upgrade_state_dict.
    state = _upgrade_state_dict(state)

    # load model parameters
    model.load_state_dict(state['model'])

    # Only restore optimizer / lr_scheduler state when the checkpoint was
    # produced with the same criterion class (the optimizer type itself is not
    # compared). An empty history has nothing to match, so skip the restore
    # instead of raising IndexError on optim_history[-1].
    optim_history = state['optimizer_history']
    if optim_history:
        last_optim = optim_history[-1]
        if last_optim['criterion_name'] == criterion.__class__.__name__:
            optimizer.load_state_dict(state['last_optimizer_state'])
            lr_scheduler.best = last_optim['best_loss']

    return state['extra_state'], optim_history
评论列表
文章目录