utils.py 文件源码

python
阅读 46 收藏 0 点赞 0 评论 0

项目:ParlAI 作者: facebookresearch 项目源码 文件源码
def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=None):
    """Restore model (and, when compatible, optimizer) state from a checkpoint.

    Args:
        filename: Path of the checkpoint file to read.
        model: Object whose ``load_state_dict`` receives the ``'model'`` entry.
        criterion: Current criterion; its class name is compared against the
            ``'criterion_name'`` stored in the newest optimizer-history entry.
        optimizer: Object whose ``load_state_dict`` receives
            ``'last_optimizer_state'`` when the criterion names match.
        lr_scheduler: Object whose ``best`` attribute is set to the stored
            ``'best_loss'`` when the criterion names match.
        cuda_device: Optional device index. When given, loaded storages are
            remapped through ``default_restore_location`` to
            ``'cuda:<cuda_device>'``.

    Returns:
        A ``(extra_state, optimizer_history)`` pair taken from the checkpoint,
        or ``(None, [])`` when ``filename`` does not exist.

    Note:
        ``torch.load`` unpickles the file, so only load trusted checkpoints.
    """
    # Nothing to restore if the checkpoint is absent.
    if not os.path.exists(filename):
        return None, []

    if cuda_device is not None:
        def _to_cuda_device(storage, _saved_location):
            # Ignore the location recorded in the file and use the requested device.
            return default_restore_location(storage, 'cuda:{}'.format(cuda_device))

        checkpoint = torch.load(filename, map_location=_to_cuda_device)
    else:
        checkpoint = torch.load(filename)

    # Presumably converts older checkpoint layouts to the current one;
    # the helper is defined elsewhere in this module.
    checkpoint = _upgrade_state_dict(checkpoint)

    # Model parameters are always restored.
    model.load_state_dict(checkpoint['model'])

    # Optimizer and scheduler state are restored only when the checkpoint's
    # most recent history entry was produced with the same criterion class.
    history = checkpoint['optimizer_history']
    latest_entry = history[-1]
    criterion_matches = latest_entry['criterion_name'] == criterion.__class__.__name__
    if criterion_matches:
        optimizer.load_state_dict(checkpoint['last_optimizer_state'])
        lr_scheduler.best = latest_entry['best_loss']

    return checkpoint['extra_state'], history
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号