fnet_model.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:pytorch_fnet 作者: AllenCellModeling 项目源码 文件源码
def load_state(self, path_load):
        state_dict = torch.load(path_load)
        self.nn_module = state_dict['nn_module']
        self._init_model()

        # load nn state
        module = self.net.module if isinstance(self.net, torch.nn.DataParallel) else self.net
        module.cpu()
        module.load_state_dict(state_dict['nn_state'])
        if self.gpu_ids[0] != -1:
            module.cuda(self.gpu_ids[0])
        # load optimizer state
        self.optimizer.state = _set_gpu_recursive(self.optimizer.state, -1)
        self.optimizer.load_state_dict(state_dict['optimizer_state'])
        self.optimizer.state = _set_gpu_recursive(self.optimizer.state, self.gpu_ids[0])

        self.count_iter = state_dict['count_iter']
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号