python类load()的实例源码

spatial_cnn.py 文件源码 项目:two-stream-action-recognition 作者: jeffreyhuang1 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def resume_and_evaluate(self):
        """Optionally restore training state from a checkpoint, then optionally validate.

        Resume: when ``self.resume`` is a non-empty path naming an existing file, the
        checkpoint is read with ``torch.load``. Its restored keys are:

        * ``'epoch'`` -> ``self.start_epoch``
        * ``'best_prec1'`` -> ``self.best_prec1``
        * ``'state_dict'`` -> ``self.model``
        * ``'optimizer'`` -> ``self.optimizer``

        A non-empty path that is not a file only prints a notice.

        Evaluate: when ``self.evaluate`` is truthy, ``self.epoch`` is reset to 0 and
        one validation pass (``self.validate_1epoch``) is run. Its results are
        discarded.

        Returns None.
        """
        ckpt_path = self.resume
        if ckpt_path:
            if not os.path.isfile(ckpt_path):
                print("==> no checkpoint found at '{}'".format(ckpt_path))
            else:
                print("==> loading checkpoint '{}'".format(ckpt_path))
                ckpt = torch.load(ckpt_path)
                self.start_epoch = ckpt['epoch']
                self.best_prec1 = ckpt['best_prec1']
                self.model.load_state_dict(ckpt['state_dict'])
                self.optimizer.load_state_dict(ckpt['optimizer'])
                message = "==> loaded checkpoint '{}' (epoch {}) (best_prec1 {})"
                print(message.format(ckpt_path, ckpt['epoch'], self.best_prec1))

        if not self.evaluate:
            return
        self.epoch = 0
        # Unpacking enforces the (prec1, val_loss) return shape; the values are unused.
        _prec1, _val_loss = self.validate_1epoch()
fnet_model.py 文件源码 项目:pytorch_fnet 作者: AllenCellModeling 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def load_state(self, path_load):
        """Restore the network, optimizer and iteration counter from a checkpoint.

        Args:
            path_load: path of a file written with ``torch.save``. It contains the
                keys ``'nn_module'``, ``'nn_state'``, ``'optimizer_state'`` and
                ``'count_iter'``.

        Restore sequence:

        1. ``self.nn_module`` is set from the checkpoint and the network is rebuilt
           with ``self._init_model()``.
        2. Weights are loaded while the module sits on the CPU.
        3. The module is moved to ``self.gpu_ids[0]`` unless that id is -1.
        4. Optimizer state is passed through ``_set_gpu_recursive`` with -1 before
           loading, and with ``self.gpu_ids[0]`` afterwards. That helper is defined
           elsewhere; presumably -1 means CPU.
        """
        # Deserialize every storage onto the CPU. The weights and optimizer state are
        # relocated explicitly below anyway, and this lets a checkpoint saved on a GPU
        # be restored on a CPU-only machine (gpu_ids[0] == -1). The callable form of
        # map_location is used because it is accepted by all torch versions.
        state_dict = torch.load(path_load, map_location=lambda storage, loc: storage)
        self.nn_module = state_dict['nn_module']
        self._init_model()

        # load nn state into the underlying module when wrapped in DataParallel
        module = self.net.module if isinstance(self.net, torch.nn.DataParallel) else self.net
        module.cpu()
        module.load_state_dict(state_dict['nn_state'])
        if self.gpu_ids[0] != -1:
            module.cuda(self.gpu_ids[0])
        # load optimizer state: relocate with -1 before loading, then to the target id
        self.optimizer.state = _set_gpu_recursive(self.optimizer.state, -1)
        self.optimizer.load_state_dict(state_dict['optimizer_state'])
        self.optimizer.state = _set_gpu_recursive(self.optimizer.state, self.gpu_ids[0])

        self.count_iter = state_dict['count_iter']
train.py 文件源码 项目:samplernn-pytorch 作者: deepsound-project 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def load_last_checkpoint(checkpoints_path):
    """Load the most recent "last" checkpoint saved by ``SaverPlugin``.

    Candidate files are those in ``checkpoints_path`` matching
    ``SaverPlugin.last_pattern`` with both placeholders wildcarded. They are
    natural-sorted, and the last one whose file name parses with numeric epoch and
    iteration fields is loaded.

    Args:
        checkpoints_path: directory to search.

    Returns:
        ``(state, epoch, iteration)``, where ``state`` is the ``torch.load`` result.
        Returns ``None`` if no parsable checkpoint file exists.
    """
    checkpoints_pattern = os.path.join(
        checkpoints_path, SaverPlugin.last_pattern.format('*', '*')
    )
    name_regex = re.compile(SaverPlugin.last_pattern.format(r'(\d+)', r'(\d+)'))
    # Walk from the newest candidate backwards. The glob wildcard also matches
    # names whose fields are not purely numeric; for those ``match`` is None,
    # which previously crashed with AttributeError.
    for checkpoint_path in reversed(natsorted(glob(checkpoints_pattern))):
        match = name_regex.match(os.path.basename(checkpoint_path))
        if match is None:
            continue
        epoch = int(match.group(1))
        iteration = int(match.group(2))
        return (torch.load(checkpoint_path), epoch, iteration)
    return None
utils.py 文件源码 项目:pytorch-adda 作者: corenel 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def init_model(net, restore):
    """Initialise ``net``'s weights, optionally restore saved ones, and move it to CUDA.

    Args:
        net: the module to initialise. ``init_weights`` is applied to it and to
            every submodule via ``net.apply``.
        restore: path to a saved state dict, or None. Weights are loaded only when
            the path is not None and exists; ``net.restored`` is then set to True.

    Returns:
        The same ``net`` object. When CUDA is available it is moved to the GPU and
        ``cudnn.benchmark`` is enabled.
    """
    net.apply(init_weights)

    can_restore = restore is not None and os.path.exists(restore)
    if can_restore:
        weights = torch.load(restore)
        net.load_state_dict(weights)
        net.restored = True
        print("Restore model from: {}".format(os.path.abspath(restore)))

    if not torch.cuda.is_available():
        return net
    cudnn.benchmark = True
    net.cuda()
    return net
test_cuda.py 文件源码 项目:pytorch 作者: tylergenter 项目源码 文件源码 阅读 34 收藏 0 点赞 0 评论 0
def test_serialization_array_with_storage(self):
        """Round-trip a list of CUDA tensors and a storage through torch.save/load.

        Checks that:

        * aliasing survives serialization: both loaded entries for ``x`` share data,
          and the loaded storage stays tied to the loaded ``y``;
        * the CUDA tensor/storage types are preserved.
        """
        x = torch.randn(5, 5).cuda()
        y = torch.IntTensor(2, 5).fill_(0).cuda()
        q = [x, y, x, y.storage()]
        with tempfile.NamedTemporaryFile() as f:
            torch.save(q, f)
            f.seek(0)
            q_copy = torch.load(f)
        self.assertEqual(q_copy, q, 0)
        # Writing through one deserialized alias of ``x`` must show up in the other.
        q_copy[0].fill_(5)
        self.assertEqual(q_copy[0], q_copy[2], 0)
        # NOTE(review): DoubleTensor assumes the suite's default tensor type is
        # double (torch.randn uses the default type) — set elsewhere, confirm.
        self.assertTrue(isinstance(q_copy[0], torch.cuda.DoubleTensor))
        self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor))
        self.assertTrue(isinstance(q_copy[2], torch.cuda.DoubleTensor))
        self.assertTrue(isinstance(q_copy[3], torch.cuda.IntStorage))
        # Filling the loaded ``y`` must be visible through the loaded storage. This
        # must be assertEqual: assertTrue(a, b) uses ``b`` only as the failure
        # message and would never compare the two storages.
        q_copy[1].fill_(10)
        self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))
test_cuda.py 文件源码 项目:pytorch 作者: tylergenter 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def test_multigpu_serialization_remap(self):
        """A callable ``map_location`` remaps storages saved on 'cuda:1' onto GPU 0.

        For any other location the remap function returns None, so ``torch.load``
        uses its default placement; here that is the saved device, 'cuda:0'.
        """
        def remap_to_gpu0(storage, location):
            if location != 'cuda:1':
                return None
            return storage.cuda(0)

        originals = [torch.randn(4, 4).cuda(0), torch.randn(4, 4).cuda(1)]

        with tempfile.NamedTemporaryFile() as tmp:
            torch.save(originals, tmp)
            tmp.seek(0)
            restored = torch.load(tmp, map_location=remap_to_gpu0)

        for src, dst in zip(originals, restored):
            self.assertEqual(dst, src)
            self.assertIs(type(dst), type(src))
            self.assertEqual(dst.get_device(), 0)
common_nn.py 文件源码 项目:pytorch 作者: tylergenter 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def __call__(self, test_case):
        """Run this module test against ``test_case``.

        Steps, in order:
          1. Build the module with ``self.constructor(*self.constructor_args)`` and
             obtain the input from ``self._get_input()``.
          2. If ``self.reference_fn`` is set, compare the module's forward output
             with the reference output. A ``Variable`` result is unwrapped to its
             ``.data``. The reference is computed from a deep copy of the unpacked
             input and the first element of ``test_case._get_parameters(module)``.
          3. Run ``self.test_noncontig`` on the same module and input.
          4. Save the module with ``torch.save`` to a temporary file, load it back,
             and assert the original and the copy give equal forward outputs.
          5. Delegate the remaining checks to ``self._do_test``.

        Args:
            test_case: the running test case; provides ``_forward``,
                ``_get_parameters`` and ``assertEqual``.
        """
        module = self.constructor(*self.constructor_args)
        input = self._get_input()

        if self.reference_fn is not None:
            out = test_case._forward(module, input)
            if isinstance(out, Variable):
                out = out.data
            # presumably deep-copied so the reference function cannot mutate the
            # input reused by the later steps
            ref_input = self._unpack_input(deepcopy(input))
            expected_out = self.reference_fn(ref_input, test_case._get_parameters(module)[0])
            test_case.assertEqual(out, expected_out)

        self.test_noncontig(test_case, module, input)

        # TODO: do this with in-memory files as soon as torch.save will support it
        with TemporaryFile() as f:
            # NOTE(review): this forward result is discarded; presumably it runs so
            # that any state a forward pass creates/updates exists before saving.
            # Confirm before removing.
            test_case._forward(module, input)
            torch.save(module, f)
            f.seek(0)
            module_copy = torch.load(f)
            test_case.assertEqual(test_case._forward(module, input), test_case._forward(module_copy, input))

        self._do_test(test_case, module, input)
test_torch.py 文件源码 项目:pytorch 作者: tylergenter 项目源码 文件源码 阅读 37 收藏 0 点赞 0 评论 0
def test_serialization_map_location(self):
        """A GPU-saved tensor loads as a CPU FloatTensor via callable or dict ``map_location``.

        The fixture ``gpu_tensors.pt`` is fetched with ``download_file`` into the
        ``data`` directory next to this file. If that fails, a RuntimeWarning is
        emitted and the test returns early.
        """
        DATA_URL = 'https://download.pytorch.org/test_data/gpu_tensors.pt'
        test_file_path = os.path.join(os.path.dirname(__file__), 'data', 'gpu_tensors.pt')
        if not download_file(DATA_URL, test_file_path):
            warnings.warn(
                "Couldn't download the test file for map_location! "
                "Tests will be incomplete!", RuntimeWarning)
            return

        def keep_on_cpu(storage, loc):
            # Storages are deserialized on the CPU first; returning them unchanged keeps them there.
            return storage

        expected = torch.FloatTensor([[1.0, 2.0], [3.0, 4.0]])
        for location in (keep_on_cpu, {'cuda:0': 'cpu'}):
            tensor = torch.load(test_file_path, map_location=location)
            self.assertEqual(type(tensor), torch.FloatTensor)
            self.assertEqual(tensor, expected)
ensamble.py 文件源码 项目:PyTorchText 作者: chenyuntc 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def ensamble(file1, file2, label_path=label_path, test_data_path=test_data_path,
             result_csv=None, weight=9.0, k=5):
    """Blend two saved score tensors and write the top-k labels per row to a CSV.

    Args:
        file1, file2: paths of score tensors saved with ``torch.save``. They must be
            compatible for ``weight * a + b``; presumably both are
            ``(num_samples, num_labels)`` — confirm against the producer.
        label_path: JSON file whose ``'id2label'`` mapping is keyed by the label
            index rendered as a string.
        test_data_path: archive loaded with ``np.load``. Its ``'index2qid'`` entry,
            unwrapped with ``.item()``, maps row index to question id.
        result_csv: output path. Defaults to a timestamped ``%y%m%d_%H%M%S.csv``
            name in the current directory.
        weight: multiplier applied to the first score tensor (default 9.0, the
            previously hard-coded value).
        k: number of labels written per row (default 5, previously hard-coded).

    Each CSV row is ``[qid, label_1, ..., label_k]``, with labels ordered by
    descending blended score.
    """
    import csv
    import json
    import time

    import numpy as np
    import torch as t

    if result_csv is None:
        result_csv = time.strftime('%y%m%d_%H%M%S.csv')
    a = t.load(file1)
    b = t.load(file2)
    scores = weight * a + b
    # Indices of the k best labels per row. tolist() yields plain ints, so the
    # str(...) lookup below gives e.g. '12' rather than 'tensor(12)' on newer torch.
    top_indices = scores.topk(k, 1)[1].tolist()

    # '.item()' implies an object array, which numpy >= 1.16.3 refuses to load
    # unless pickle loading is explicitly allowed. Only use trusted files here.
    index2qid = np.load(test_data_path, allow_pickle=True)['index2qid'].item()
    with open(label_path) as f:
        label2qid = json.load(f)['id2label']

    # Build a real list: item assignment into range(...) fails on Python 3.
    rows = [[index2qid[ii]] + [label2qid[str(label)] for label in labels]
            for ii, labels in enumerate(top_indices)]
    with open(result_csv, 'w') as f:
        writer = csv.writer(f)
        writer.writerows(rows)
models.py 文件源码 项目:Structured-Self-Attentive-Sentence-Embedding 作者: ExplorerFreda 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def __init__(self, config):
        """Build the embedding + bidirectional-LSTM encoder described by ``config``.

        Reads these keys from ``config``: ``'dropout'``, ``'ntoken'``, ``'ninp'``,
        ``'nhid'``, ``'nlayers'``, ``'pooling'``, ``'dictionary'`` and
        ``'word-vector'``.

        The embedding row for ``'<pad>'`` is zeroed. If the ``'word-vector'`` path
        exists, it is loaded with ``torch.load`` as an indexable
        ``(vocab, vectors, size)`` triple, and ``size`` must be >= ``ninp``
        (presumably the stored vector dimension). Each dictionary word found in
        ``vocab`` then gets the first ``ninp`` values of its stored vector copied
        into its embedding row.
        """
        super(BiLSTM, self).__init__()
        self.drop = nn.Dropout(config['dropout'])
        self.encoder = nn.Embedding(config['ntoken'], config['ninp'])
        self.bilstm = nn.LSTM(
            config['ninp'],
            config['nhid'],
            config['nlayers'],
            dropout=config['dropout'],
            bidirectional=True,
        )
        self.nlayers, self.nhid = config['nlayers'], config['nhid']
        self.pooling = config['pooling']
        self.dictionary = config['dictionary']
        # Note: self.init_weights() is not invoked from the constructor.

        pad_row = self.dictionary.word2idx['<pad>']
        self.encoder.weight.data[pad_row] = 0

        vector_file = config['word-vector']
        if os.path.exists(vector_file):
            print('Loading word vectors from', vector_file)
            saved = torch.load(vector_file)
            assert saved[2] >= config['ninp']
            vocab, table = saved[0], saved[1]
            n_loaded = 0
            for word in self.dictionary.word2idx:
                if word in vocab:
                    row = self.dictionary.word2idx[word]
                    self.encoder.weight.data[row] = table[vocab[word]][:config['ninp']]
                    n_loaded += 1
            print('%d words from external word vectors loaded.' % n_loaded)

    # note: init_range constrains the values of the initial weights
validate.py 文件源码 项目:pytorch-semseg 作者: meetshah1995 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def validate(args):
    """Evaluate a saved segmentation model on a dataset split and print its scores.

    Args:
        args: namespace providing ``dataset``, ``split``, ``img_rows``,
            ``img_cols``, ``batch_size`` and ``model_path``. The architecture name
            is the model file's base name up to its first ``'_'``; for example, a
            file named ``fcn8s_pascal.pkl`` selects ``fcn8s``.

    Prints every entry of the overall score dict, then the IoU of each class.
    The model and every batch are moved to the GPU, so CUDA is required.
    """
    import os

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    loader = data_loader(data_path, split=args.split, is_transform=True, img_size=(args.img_rows, args.img_cols))
    n_classes = loader.n_classes
    valloader = data.DataLoader(loader, batch_size=args.batch_size, num_workers=4)
    running_metrics = runningScore(n_classes)

    # Setup Model. Take the architecture name from the file name only, so that
    # directories in the path (or underscores in them) cannot leak into it.
    # split() also keeps the whole name when it contains no '_'; the previous
    # find()-based slice chopped the last character in that case.
    model_file_name = os.path.basename(args.model_path)
    model_name = model_file_name.split('_', 1)[0]
    model = get_model(model_name, n_classes)
    state = convert_state_dict(torch.load(args.model_path)['model_state'])
    model.load_state_dict(state)
    model.eval()
    # Move the model to the GPU once instead of on every batch.
    model.cuda()

    for images, labels in tqdm(valloader):
        # volatile=True is the pre-0.4 autograd switch for inference-only Variables.
        images = Variable(images.cuda(), volatile=True)
        labels = Variable(labels.cuda(), volatile=True)

        outputs = model(images)
        pred = outputs.data.max(1)[1].cpu().numpy()
        gt = labels.data.cpu().numpy()

        running_metrics.update(gt, pred)

    score, class_iou = running_metrics.get_scores()

    for k, v in score.items():
        print(k, v)

    for i in range(n_classes):
        print(i, class_iou[i])
saliency.py 文件源码 项目:DeepLearning_PlantDiseases 作者: MarkoArsenovic 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def load_defined_model(path, num_classes, name):
    """Build model ``name`` from ``models`` and load a checkpoint's weights into it.

    Args:
        path: checkpoint file whose ``'state_dict'`` entry holds the weights. Keys
            may carry a leading ``module.`` prefix, as typically added when saving
            a ``DataParallel`` wrapper.
        num_classes: passed to the model constructor as ``num_classes``.
        name: key of the constructor in ``models.__dict__``.

    Returns:
        The model with the checkpoint weights loaded.

    Raises:
        AssertionError: if ``diff_states`` reports any difference between the
            model's state dict and the checkpoint. The first element of each
            reported difference is printed before failing.
    """
    model = models.__dict__[name](num_classes=num_classes)
    pretrained_state = torch.load(path)
    new_pretrained_state = OrderedDict()

    # Strip only leading "module." prefixes. str.replace would also mangle key
    # names that merely contain the substring, e.g. "submodule.weight".
    prefix = "module."
    for k, v in pretrained_state['state_dict'].items():
        layer_name = k
        while layer_name.startswith(prefix):
            layer_name = layer_name[len(prefix):]
        new_pretrained_state[layer_name] = v

    # Diff
    diff = list(diff_states(model.state_dict(), new_pretrained_state))
    if diff:
        print("Mismatch in these layers :", name, ":", [d[0] for d in diff])

    assert len(diff) == 0

    # Merge
    model.load_state_dict(new_pretrained_state)
    return model


#Load the model
occlusion.py 文件源码 项目:DeepLearning_PlantDiseases 作者: MarkoArsenovic 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def load_defined_model(path, num_classes, name):
    """Build model ``name`` from ``models`` and load a checkpoint's weights into it.

    Args:
        path: checkpoint file whose ``'state_dict'`` entry holds the weights. Keys
            may carry a leading ``module.`` prefix, as typically added when saving
            a ``DataParallel`` wrapper.
        num_classes: passed to the model constructor as ``num_classes``.
        name: key of the constructor in ``models.__dict__``.

    Returns:
        The model with the checkpoint weights loaded.

    Raises:
        AssertionError: if ``diff_states`` reports any difference between the
            model's state dict and the checkpoint. The first element of each
            reported difference is printed before failing.
    """
    model = models.__dict__[name](num_classes=num_classes)
    pretrained_state = torch.load(path)
    new_pretrained_state = OrderedDict()

    # Strip only leading "module." prefixes. str.replace would also mangle key
    # names that merely contain the substring, e.g. "submodule.weight".
    prefix = "module."
    for k, v in pretrained_state['state_dict'].items():
        layer_name = k
        while layer_name.startswith(prefix):
            layer_name = layer_name[len(prefix):]
        new_pretrained_state[layer_name] = v

    # Diff
    diff = list(diff_states(model.state_dict(), new_pretrained_state))
    if diff:
        print("Mismatch in these layers :", name, ":", [d[0] for d in diff])

    assert len(diff) == 0

    # Merge
    model.load_state_dict(new_pretrained_state)
    return model


#Load the model
main.py 文件源码 项目:SGAN 作者: YuhangSong 项目源码 文件源码 阅读 37 收藏 0 点赞 0 评论 0
def restore_model():
    """Best-effort reload of ``netD`` and ``netG`` weights from ``LOGDIR``.

    Each network is loaded independently, from ``'<LOGDIR>/netD.pth'`` and
    ``'<LOGDIR>/netG.pth'`` respectively. Any exception raised while loading
    (missing file, incompatible state dict, ...) is reported as "unfounded", and
    the function moves on to the next network.
    """
    print('Trying load models....')
    # ``except Exception:`` replaces the Python-2-only ``except Exception, e:``,
    # which is a SyntaxError on Python 3; the exception object was never used.
    try:
        netD.load_state_dict(torch.load('{0}/netD.pth'.format(LOGDIR)))
        print('Previous checkpoint for netD founded')
    except Exception:
        print('Previous checkpoint for netD unfounded')
    try:
        netG.load_state_dict(torch.load('{0}/netG.pth'.format(LOGDIR)))
        print('Previous checkpoint for netG founded')
    except Exception:
        print('Previous checkpoint for netG unfounded')
    print('')
cp_model.py 文件源码 项目:PaintsPytorch 作者: orashi 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def def_netF():
    """Return a frozen VGG16 feature extractor truncated to its first 9 child modules.

    Weights are read from the local file ``'vgg16-397923af.pth'`` (presumably the
    torchvision pretrained checkpoint). Every parameter of the network, including
    the unused classifier, has ``requires_grad`` set to False.

    Returns:
        ``nn.Sequential`` holding the first 9 children of ``vgg16.features``.
    """
    backbone = M.vgg16()
    backbone.load_state_dict(torch.load('vgg16-397923af.pth'))
    kept_layers = list(backbone.features.children())[:9]
    backbone.features = nn.Sequential(*kept_layers)
    # Freeze every parameter so no gradients are computed for them.
    for weight in backbone.parameters():
        weight.requires_grad = False
    return backbone.features
dev_model.py 文件源码 项目:PaintsPytorch 作者: orashi 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def def_netF():
    """Return a frozen VGG16 feature extractor truncated to its first 9 child modules.

    Weights are read from the local file ``'vgg16-397923af.pth'`` (presumably the
    torchvision pretrained checkpoint). Every parameter of the network, including
    the unused classifier, has ``requires_grad`` set to False.

    Returns:
        ``nn.Sequential`` holding the first 9 children of ``vgg16.features``.
    """
    backbone = M.vgg16()
    backbone.load_state_dict(torch.load('vgg16-397923af.pth'))
    kept_layers = list(backbone.features.children())[:9]
    backbone.features = nn.Sequential(*kept_layers)
    # Freeze every parameter so no gradients are computed for them.
    for weight in backbone.parameters():
        weight.requires_grad = False
    return backbone.features
feat_bn_model.py 文件源码 项目:PaintsPytorch 作者: orashi 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def def_netF():
    """Return a frozen VGG16-BN feature extractor truncated to its first 13 child modules.

    Weights are read from the local file ``'vgg16_bn-6c64b313.pth'`` (presumably
    the torchvision pretrained checkpoint). Every parameter of the network has
    ``requires_grad`` set to False, and the returned extractor is in eval mode.

    Returns:
        ``nn.Sequential`` holding the first 13 children of ``vgg16.features``, in
        eval mode.
    """
    vgg16 = M.vgg16_bn()
    vgg16.load_state_dict(torch.load('vgg16_bn-6c64b313.pth'))
    vgg16.features = nn.Sequential(
        *list(vgg16.features.children())[:13]
    )
    for param in vgg16.parameters():
        param.requires_grad = False
    features = vgg16.features
    # requires_grad=False does not freeze BatchNorm. In train mode BN normalises
    # with per-batch statistics and keeps updating its running mean/var buffers;
    # eval() makes it use the stored statistics instead. A later .train() call on
    # this module (or on a parent it is registered in) would undo this.
    features.eval()
    return features
pack_model.py 文件源码 项目:PaintsPytorch 作者: orashi 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def def_netF():
    """Return a frozen VGG16 feature extractor truncated to its first 9 child modules.

    Weights are read from the local file ``'vgg16-397923af.pth'`` (presumably the
    torchvision pretrained checkpoint). Every parameter of the network, including
    the unused classifier, has ``requires_grad`` set to False.

    Returns:
        ``nn.Sequential`` holding the first 9 children of ``vgg16.features``.
    """
    backbone = M.vgg16()
    backbone.load_state_dict(torch.load('vgg16-397923af.pth'))
    kept_layers = list(backbone.features.children())[:9]
    backbone.features = nn.Sequential(*kept_layers)
    # Freeze every parameter so no gradients are computed for them.
    for weight in backbone.parameters():
        weight.requires_grad = False
    return backbone.features
pro_model.py 文件源码 项目:PaintsPytorch 作者: orashi 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def def_netF():
    """Return a frozen VGG16 feature extractor truncated to its first 9 child modules.

    Weights are read from the local file ``'vgg16-397923af.pth'`` (presumably the
    torchvision pretrained checkpoint). Every parameter of the network, including
    the unused classifier, has ``requires_grad`` set to False.

    Returns:
        ``nn.Sequential`` holding the first 9 children of ``vgg16.features``.
    """
    backbone = M.vgg16()
    backbone.load_state_dict(torch.load('vgg16-397923af.pth'))
    kept_layers = list(backbone.features.children())[:9]
    backbone.features = nn.Sequential(*kept_layers)
    # Freeze every parameter so no gradients are computed for them.
    for weight in backbone.parameters():
        weight.requires_grad = False
    return backbone.features
ins_mode.py 文件源码 项目:PaintsPytorch 作者: orashi 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def def_netF():
    """Return a frozen VGG16 feature extractor truncated to its first 9 child modules.

    Weights are read from the local file ``'vgg16-397923af.pth'`` (presumably the
    torchvision pretrained checkpoint). Every parameter of the network, including
    the unused classifier, has ``requires_grad`` set to False.

    Returns:
        ``nn.Sequential`` holding the first 9 children of ``vgg16.features``.
    """
    backbone = M.vgg16()
    backbone.load_state_dict(torch.load('vgg16-397923af.pth'))
    kept_layers = list(backbone.features.children())[:9]
    backbone.features = nn.Sequential(*kept_layers)
    # Freeze every parameter so no gradients are computed for them.
    for weight in backbone.parameters():
        weight.requires_grad = False
    return backbone.features


问题


面经


文章

微信
公众号

扫码关注公众号