def load_state(self, path_load):
    """Restore the model, optimizer state and iteration counter from a checkpoint.

    The checkpoint is expected to be a dict (as written by ``torch.save``)
    containing the keys ``'nn_module'``, ``'nn_state'``, ``'optimizer_state'``
    and ``'count_iter'``.

    Args:
        path_load: Path of the checkpoint file handed to ``torch.load``.

    Raises:
        KeyError: If one of the expected keys is missing from the checkpoint.
    """
    # Deserialize every tensor onto the CPU first. Without ``map_location``,
    # tensors are restored onto the device they were saved from, which fails
    # on a CPU-only run (gpu_ids[0] == -1) or when that GPU index does not
    # exist on this machine. Tensors are moved to the configured device below.
    # NOTE(review): torch.load unpickles arbitrary Python objects; only load
    # checkpoints from trusted sources.
    state_dict = torch.load(path_load, map_location='cpu')
    self.nn_module = state_dict['nn_module']
    # Presumably rebuilds self.net and self.optimizer from self.nn_module --
    # TODO confirm against _init_model.
    self._init_model()

    # Load network weights while the module sits on the CPU, then move it to
    # the target GPU (gpu_ids[0] == -1 means "stay on CPU").
    net = self.net
    module = net.module if isinstance(net, torch.nn.DataParallel) else net
    module.cpu()
    module.load_state_dict(state_dict['nn_state'])
    if self.gpu_ids[0] != -1:
        module.cuda(self.gpu_ids[0])

    # Load optimizer state. _set_gpu_recursive appears to move every tensor in
    # the state to the given device, with -1 meaning CPU -- assumption, verify
    # against its definition.
    self.optimizer.state = _set_gpu_recursive(self.optimizer.state, -1)
    self.optimizer.load_state_dict(state_dict['optimizer_state'])
    self.optimizer.state = _set_gpu_recursive(self.optimizer.state, self.gpu_ids[0])
    self.count_iter = state_dict['count_iter']
评论列表
文章目录