import os

import torch
from torch.serialization import default_restore_location


def load_model_state(filename, model, cuda_device=None):
    if not os.path.exists(filename):
        return None, [], None
    if cuda_device is None:
        state = torch.load(filename)
    else:
        # remap every tensor in the checkpoint onto the requested CUDA device
        state = torch.load(
            filename,
            map_location=lambda s, l: default_restore_location(s, 'cuda:{}'.format(cuda_device))
        )
    # _upgrade_state_dict is a helper expected to be defined elsewhere in this module;
    # it migrates checkpoints saved by older code to the current layout
    state = _upgrade_state_dict(state)
    state['model'] = model.upgrade_state_dict(state['model'])

    # load model parameters
    try:
        model.load_state_dict(state['model'])
    except Exception:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match')

    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']
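
A minimal usage sketch, assuming `load_model_state` and the `_upgrade_state_dict` helper it calls are importable from this module; `TinyModel`, the checkpoint path, and the device index are hypothetical placeholders standing in for a real model and setup:

import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in; real models expose upgrade_state_dict() to migrate legacy checkpoints."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def upgrade_state_dict(self, state_dict):
        # nothing to migrate in this toy model, so return the dict unchanged
        return state_dict


model = TinyModel()
extra_state, optimizer_history, last_optimizer_state = load_model_state(
    'checkpoints/checkpoint_best.pt',  # hypothetical path; a missing file yields (None, [], None)
    model,
    cuda_device=0,  # remap saved tensors onto cuda:0; pass None to keep the saved device mapping
)
if extra_state is None:
    print('No checkpoint found; starting from scratch.')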