python类load()的实例源码

segment.py 文件源码 项目:torch_light 作者: ne7ermore 项目源码 文件源码 阅读 33 收藏 0 点赞 0 评论 0
def __init__(self, model_source="model", cuda=False):
    """Load a trained BiLSTM_Cut checkpoint and prepare it for inference.

    Args:
        model_source: checkpoint path given to ``torch.load``; the loaded
            object must provide "src_dict", "trains_score", "settings" and
            "model" entries.
        cuda: when true, load storages where they were saved and move the
            model (and its softmax projection) to the GPU; otherwise remap
            every storage to the CPU and keep the model there.
    """
    self.cuda = cuda
    self.torch = torch.cuda if cuda else torch

    if self.cuda:
        checkpoint = torch.load(model_source)
    else:
        # Identity map_location keeps each deserialized storage on the CPU,
        # so a checkpoint written from a GPU still opens without one.
        checkpoint = torch.load(model_source, map_location=lambda storage, loc: storage)

    self.src_dict = checkpoint["src_dict"]
    self.trains_score = checkpoint["trains_score"]
    args = checkpoint["settings"]
    self.args = args

    net = BiLSTM_Cut(args)
    net.load_state_dict(checkpoint['model'])

    softmax = nn.Softmax()
    if self.cuda:
        net = net.cuda()
        net.prob_projection = softmax.cuda()
    else:
        net = net.cpu()
        net.prob_projection = softmax.cpu()

    self.model = net.eval()
demo.py 文件源码 项目:pytorch.rl.learning 作者: moskomule 项目源码 文件源码 阅读 35 收藏 0 点赞 0 评论 0
def main(env, weight_path, epsilon):
    """Play and render one epsilon-greedy episode with a trained DQN.

    Args:
        env: environment id handed to ``make_atari``.
        weight_path: path of the ``state_dict`` loaded into the ``DQN``.
        epsilon: probability of taking a uniformly random action each step.
    """
    env = make_atari(env)
    q_function = DQN(env.action_space.n)
    q_function.load_state_dict(torch.load(weight_path))

    done = False
    state = env.reset()
    step = 1
    sleep(2)  # short pause before the episode starts
    while not done:
        env.render()
        if random() <= epsilon:
            action = randrange(0, env.action_space.n)
        else:
            # Greedy step: add a batch dimension, then take the arg-max Q-value.
            batch = variable(to_tensor(state).unsqueeze(0))
            flat_q = q_function(batch).data.view(-1)
            # NOTE(review): ``.sum()`` presumably collapses the 1-element
            # index tensor into a scalar action; confirm for the torch version used.
            action = flat_q.max(dim=0)[1].sum()

        state, reward, done, info = env.step(action)
        print(f"[step: {step:>5}] [reward: {reward:>5}]")
        step += 1
    sleep(2)  # short pause after the episode ends
test_cuda.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 31 收藏 0 点赞 0 评论 0
def test_serialization(self):
    """Round-trip CUDA tensors and a storage through torch.save/torch.load.

    ``q`` holds ``x`` twice, ``y`` once and ``y``'s storage, so besides types
    the test checks that objects shared before saving are shared after loading.
    """
    x = torch.randn(5, 5).cuda()
    y = torch.IntTensor(2, 5).fill_(0).cuda()
    q = [x, y, x, y.storage()]
    with tempfile.NamedTemporaryFile() as f:
        torch.save(q, f)
        f.seek(0)
        q_copy = torch.load(f)
    self.assertEqual(q_copy, q, 0)
    # Writing through q_copy[0] must be visible through q_copy[2] (same tensor).
    q_copy[0].fill_(5)
    self.assertEqual(q_copy[0], q_copy[2], 0)
    self.assertTrue(isinstance(q_copy[0], torch.cuda.DoubleTensor))
    self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor))
    self.assertTrue(isinstance(q_copy[2], torch.cuda.DoubleTensor))
    self.assertTrue(isinstance(q_copy[3], torch.cuda.IntStorage))
    # q_copy[3] must still alias q_copy[1]'s 2 * 5 = 10-element storage.
    q_copy[1].fill_(10)
    # assertEqual, not assertTrue: assertTrue's second argument is only the
    # failure message, so assertTrue(storage, expected) could never fail.
    self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10), 0)
common_nn.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 42 收藏 0 点赞 0 评论 0
def __call__(self, test_case):
    """Run this module test case against ``test_case``.

    1. When ``reference_fn`` is set, compare the module's forward output with
       the reference computed from a deep copy of the input and the module's
       first parameter group.
    2. Check that a torch.save/torch.load round trip yields a module with the
       same forward output.
    3. Delegate the remaining checks to ``self._do_test``.
    """
    module = self.constructor(*self.constructor_args)
    sample = self._get_input()

    if self.reference_fn is not None:
        actual = test_case._forward(module, sample)
        if isinstance(actual, Variable):
            actual = actual.data
        reference_input = self._unpack_input(deepcopy(sample))
        first_params = test_case._get_parameters(module)[0]
        expected = self.reference_fn(reference_input, first_params)
        test_case.assertEqual(actual, expected)

    # TODO: do this with in-memory files as soon as torch.save will support it
    with TemporaryFile() as buffer:
        # NOTE(review): this forward result is discarded; presumably it is run
        # to initialise state before saving — confirm.
        test_case._forward(module, sample)
        torch.save(module, buffer)
        buffer.seek(0)
        restored = torch.load(buffer)
        original_out = test_case._forward(module, sample)
        restored_out = test_case._forward(restored, sample)
        test_case.assertEqual(original_out, restored_out)

    self._do_test(test_case, module, sample)
test_torch.py 文件源码 项目:pytorch-dist 作者: apaszke 项目源码 文件源码 阅读 41 收藏 0 点赞 0 评论 0
def test_serialization(self):
    """Round-trip float tensors and storages through torch.save/torch.load.

    The check runs twice: once handing save/load the open file object and once
    handing them its file name. Each pass checks element types and that tensors
    and storages shared before saving are still shared after loading.
    """
    a = [torch.randn(5, 5).float() for _ in range(2)]
    # Two tensors repeated twice, then a[0]'s storage and a slice of it.
    b = [a[0], a[1], a[0], a[1], a[0].storage(), a[0].storage()[1:4]]
    for use_name in (False, True):
        with tempfile.NamedTemporaryFile() as f:
            target = f.name if use_name else f
            torch.save(b, target)
            f.seek(0)
            c = torch.load(target)
        self.assertEqual(b, c, 0)
        for position in range(4):
            self.assertTrue(isinstance(c[position], torch.FloatTensor))
        self.assertTrue(isinstance(c[4], torch.FloatStorage))
        # Filling c[0] must show through c[2] and through storage c[4].
        c[0].fill_(10)
        self.assertEqual(c[0], c[2], 0)
        self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
        c[1].fill_(20)
        self.assertEqual(c[1], c[3], 0)
        self.assertEqual(c[4], c[5][1:4], 0)
torch.py 文件源码 项目:emu 作者: mlosch 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def _load_mean_std(handle):
    """
    Load mean/std values from a file, or return ``handle`` if it already is an array.

    Parameters
    ----------
    handle : numpy.ndarray, str or os.PathLike
        An array is returned unchanged. A path ending in ``.t7`` is read with
        ``load_lua``, one ending in ``.npy`` with ``np.load``, and any other
        path with ``torch.load``. The Lua and torch branches call ``.numpy()``
        on the loaded object.

    Returns
    ----------
    mean/std : numpy.ndarray, or None when ``handle`` is none of the supported types
    """
    import os

    if isinstance(handle, np.ndarray):
        return handle
    if isinstance(handle, (str, os.PathLike)):
        # os.fspath lets pathlib.Path objects take the same route as strings.
        path = os.fspath(handle)
        if path.endswith('.t7'):
            return load_lua(path).numpy()
        if path.endswith('.npy'):
            return np.load(path)
        return torch.load(path).numpy()
    # Unsupported types yield None (made explicit; previously implicit).
    return None
generate.py 文件源码 项目:torch_light 作者: ne7ermore 项目源码 文件源码 阅读 33 收藏 0 点赞 0 评论 0
def __init__(self, model=None, model_source=None, src_dict=None, args=None):
    """Set up the generator from a live model or from a checkpoint file.

    Args:
        model: an already-built model. When given, ``src_dict`` and ``args``
            are used as-is and ``model_source`` is ignored.
        model_source: checkpoint path loaded onto the CPU when ``model`` is
            None; it must provide "src_dict", "settings" and "model" entries.
        src_dict: vocabulary dict used together with ``model``. Its items are
            inverted into ``self.idx2word``.
        args: settings object used together with ``model``; must expose
            ``bidirectional``.
    """
    assert model is not None or model_source is not None

    if model is not None:
        self.dict = src_dict
        self.args = args
    else:
        checkpoint = torch.load(model_source, map_location=lambda storage, loc: storage)
        self.dict = checkpoint["src_dict"]
        self.args = checkpoint["settings"]
        model = Model(self.args)
        model.load_state_dict(checkpoint['model'])

    self.num_directions = 2 if self.args.bidirectional else 1
    self.idx2word = {index: word for word, index in self.dict.items()}
    self.model = model.eval()
transform.py 文件源码 项目:torch_light 作者: ne7ermore 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def __init__(self, model_source, cuda=False, beam_size=3):
    """Load a trained Transformer checkpoint for decoding.

    Args:
        model_source: checkpoint path; must provide "src_dict", "tgt_dict",
            "settings" and "model" entries.
        cuda: when true, load as saved and move the model to the GPU;
            otherwise map every storage to the CPU and keep the model there.
        beam_size: stored as ``self.beam_size``.
    """
    self.torch = torch.cuda if cuda else torch
    self.cuda = cuda
    self.beam_size = beam_size

    load_kwargs = {} if self.cuda else {"map_location": lambda storage, loc: storage}
    checkpoint = torch.load(model_source, **load_kwargs)

    self.src_dict = checkpoint["src_dict"]
    self.tgt_dict = checkpoint["tgt_dict"]
    # NOTE(review): despite its name, this inverts the *target* dictionary;
    # confirm that is intended.
    self.src_idx2word = {idx: word for word, idx in checkpoint["tgt_dict"].items()}
    self.args = args = checkpoint["settings"]

    model = Transformer(args)
    model.load_state_dict(checkpoint['model'])
    model = model.cuda() if self.cuda else model.cpu()
    self.model = model.eval()
predict.py 文件源码 项目:torch_light 作者: ne7ermore 项目源码 文件源码 阅读 33 收藏 0 点赞 0 评论 0
def __init__(self, model_source, cuda=False, beam_size=3):
    """Load a trained CNN_Ranking checkpoint plus the segmenter and stopword filter.

    Args:
        model_source: checkpoint path; must provide "src_dict", "tgt_dict",
            "settings" and "model" entries.
        cuda: when true, load as saved and move the model to the GPU;
            otherwise map every storage to the CPU while loading.
        beam_size: accepted for interface compatibility; not used here.
    """
    self.torch = torch.cuda if cuda else torch
    self.cuda = cuda
    self.jb = Jieba("./segmenter_dicts", useSynonym=True, HMM=False)
    self.swf = StopwordFilter("./segmenter_dicts/stopwords.txt")

    if self.cuda:
        model_source = torch.load(model_source)
    else:
        # Map every storage to the CPU so a checkpoint saved from a GPU can be
        # opened on a CPU-only host (the model is moved to the CPU below anyway).
        model_source = torch.load(model_source, map_location=lambda storage, loc: storage)
    self.src_dict = model_source["src_dict"]
    self.tgt_dict = model_source["tgt_dict"]
    # NOTE(review): built from "tgt_dict" despite the "src" prefix; confirm intent.
    self.src_idx2ind = {v: k for k, v in model_source["tgt_dict"].items()}
    self.args = args = model_source["settings"]
    model = CNN_Ranking(args)
    model.load_state_dict(model_source['model'])

    if self.cuda:
        model = model.cuda()
    else:
        model = model.cpu()
    self.model = model.eval()
predict.py 文件源码 项目:torch_light 作者: ne7ermore 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def __init__(self, model_source, cuda=False):
    """Restore a BiLSTM_CRF_Size checkpoint and put the model in eval mode.

    Args:
        model_source: checkpoint path; must provide "src_dict",
            "trains_score", "settings" and "model" entries.
        cuda: when true, load as saved and place the model and its softmax
            projection on the GPU; otherwise map storages to the CPU and keep
            both there.
    """
    self.torch = torch.cuda if cuda else torch
    self.cuda = cuda

    load_kwargs = {} if self.cuda else {"map_location": lambda storage, loc: storage}
    checkpoint = torch.load(model_source, **load_kwargs)

    self.src_dict = checkpoint["src_dict"]
    self.trains_score = checkpoint["trains_score"]
    self.args = args = checkpoint["settings"]

    model = BiLSTM_CRF_Size(args)
    model.load_state_dict(checkpoint['model'])

    # Same device move for the model and its freshly built softmax.
    if self.cuda:
        place = lambda module: module.cuda()
    else:
        place = lambda module: module.cpu()
    model = place(model)
    model.prob_projection = place(nn.Softmax())

    self.model = model.eval()
predict.py 文件源码 项目:KagglePlanetPytorch 作者: Mctigger 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def predict_kfold(model_name, pre_transforms=None):
    """Write predictions for each of the 5 KFold splits of ``model_name``.

    Args:
        model_name: dotted module path. ``<model_name>.generate_model`` builds
            the network and ``<model_name>.random_state`` seeds the KFold
            shuffle (presumably so the folds match those used in training).
        pre_transforms: optional list forwarded to ``predict_model``. Defaults
            to a fresh empty list on every call.
    """
    if pre_transforms is None:
        # Fresh list per call: a shared ``[]`` default would carry any
        # downstream mutation over into later calls.
        pre_transforms = []

    model = locate(model_name + '.generate_model')()
    random_state = locate(model_name + '.random_state')
    print('Random state: {}'.format(random_state))

    labels_df = labels.get_labels_df()
    kf = sklearn.model_selection.KFold(n_splits=5, shuffle=True, random_state=random_state)

    for i, (train_idx, val_idx) in enumerate(kf.split(labels_df)):
        split_name = model_name + '-split_' + str(i)
        best_epoch = util.find_epoch_val(split_name)
        print('Using epoch {} for predictions'.format(best_epoch))
        epoch_name = split_name + '-epoch_' + str(best_epoch)
        # KFold yields positional indices, so select rows positionally with
        # .iloc (DataFrame.ix was removed in pandas 1.0).
        train = labels_df.iloc[train_idx]
        val = labels_df.iloc[val_idx]
        state = torch.load(os.path.join(paths.models, split_name, epoch_name))

        predict_model(model, state, train, val, output_file=split_name, pre_transforms=pre_transforms)
data_loader.py 文件源码 项目:multiNLI_encoder 作者: easonnie 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def load_data(data_root, embd_file, reseversed=True, batch_sizes=(32, 32, 32), device=-1):
    """Build unshuffled SNLI iterators whose text vocabulary uses saved vectors.

    Args:
        data_root: root directory handed to ``datasets.SNLI.splits``.
        embd_file: file read with ``torch.load`` and installed as the text
            vocabulary's ``vectors``.
        reseversed: (sic) use ``RParsedTextLField`` when true, otherwise
            ``ParsedTextLField``.
        batch_sizes: batch sizes for the train/dev/test iterators.
        device: forwarded to ``data.Iterator.splits``.

    Returns:
        (train_iter, dev_iter, test_iter, vectors)
    """
    text_field_cls = RParsedTextLField if reseversed else ParsedTextLField
    testl_field = text_field_cls()

    transitions_field = datasets.snli.ShiftReduceField()
    y_field = data.Field(sequential=False)

    splits = datasets.SNLI.splits(testl_field, y_field, transitions_field, root=data_root)
    train, dev, test = splits
    # Text vocab spans every split; the label vocab uses only training data.
    testl_field.build_vocab(train, dev, test)
    y_field.build_vocab(train)

    testl_field.vocab.vectors = torch.load(embd_file)

    iterators = data.Iterator.splits(
        (train, dev, test), batch_sizes=batch_sizes, device=device, shuffle=False)
    train_iter, dev_iter, test_iter = iterators

    return train_iter, dev_iter, test_iter, testl_field.vocab.vectors
vctk.py 文件源码 项目:audio 作者: pytorch 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def __init__(self, root, downsample=True, transform=None, target_transform=None, download=False, dev_mode=False):
    """Open the VCTK dataset and load its first processed chunk file.

    Chunk files are named ``vctk_NNNN.pt`` under ``root/processed_folder``
    and hold ``chunk_size`` (1000) samples each. Chunk 0 is loaded here, and
    ``cached_pt`` records which chunk is currently in memory.

    Raises:
        RuntimeError: if ``self._check_exists()`` is false after the optional download.
    """
    self.root = os.path.expanduser(root)
    self.downsample = downsample
    self.transform = transform
    self.target_transform = target_transform
    self.dev_mode = dev_mode
    self.data, self.labels = [], []
    self.chunk_size = 1000
    self.num_samples = 0
    self.max_len = 0
    self.cached_pt = 0

    if download:
        self.download()

    if not self._check_exists():
        raise RuntimeError('Dataset not found.' +
                           ' You can use download=True to download it')
    self._read_info()

    chunk_file = "vctk_{:04d}.pt".format(self.cached_pt)
    chunk_path = os.path.join(self.root, self.processed_folder, chunk_file)
    self.data, self.labels = torch.load(chunk_path)
vctk.py 文件源码 项目:audio 作者: pytorch 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def __getitem__(self, index):
    """Return sample ``index`` as an ``(audio, target)`` pair.

    Samples live in chunk files of ``self.chunk_size`` entries. When ``index``
    falls outside the cached chunk, that chunk's file is loaded first.

    Args:
        index (int): global sample index.

    Returns:
        tuple: (audio, target), each passed through ``self.transform`` /
        ``self.target_transform`` when those are set.
    """
    chunk = index // self.chunk_size
    if chunk != self.cached_pt:
        self.cached_pt = int(chunk)
        chunk_path = os.path.join(self.root, self.processed_folder,
                                  "vctk_{:04d}.pt".format(self.cached_pt))
        self.data, self.labels = torch.load(chunk_path)
    offset = index % self.chunk_size
    audio = self.data[offset]
    target = self.labels[offset]

    if self.transform is not None:
        audio = self.transform(audio)
    if self.target_transform is not None:
        target = self.target_transform(target)

    return audio, target
yesno.py 文件源码 项目:audio 作者: pytorch 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def __init__(self, root, transform=None, target_transform=None, download=False, dev_mode=False):
    """Open the processed YesNo dataset stored under ``root``.

    After an optional download, ``(data, labels)`` are read with
    ``torch.load`` from ``root/processed_folder/processed_file``.

    Raises:
        RuntimeError: if ``self._check_exists()`` is false after the optional download.
    """
    self.root = os.path.expanduser(root)
    self.transform, self.target_transform = transform, target_transform
    self.dev_mode = dev_mode
    self.data = []
    self.labels = []
    self.num_samples = 0
    self.max_len = 0

    if download:
        self.download()
    if not self._check_exists():
        raise RuntimeError('Dataset not found.' +
                           ' You can use download=True to download it')

    processed_path = os.path.join(self.root, self.processed_folder, self.processed_file)
    self.data, self.labels = torch.load(processed_path)
models.py 文件源码 项目:MIL.pytorch 作者: gujiuxiang 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def build_mil(opt):
    """Build the MIL model described by ``opt``, move it to the GPU and set train mode.

    A ``opt.model`` containing 'resnet101' selects ``resnet_mil``; anything
    else selects ``vgg_mil``. With ``opt.n_gpus`` (default 1) above 1, the
    model is wrapped in ``nn.DataParallel`` over ``opt.gpus``. When
    ``opt.start_from`` is set, weights are restored from
    ``<start_from>/<basename>.model-best.pth``.
    """
    opt.n_gpus = getattr(opt, 'n_gpus', 1)

    if 'resnet101' in opt.model:
        mil_model = resnet_mil(opt)
    else:
        mil_model = vgg_mil(opt)

    if opt.n_gpus > 1:
        print('Construct multi-gpu model ...')
        model = nn.DataParallel(mil_model, device_ids=opt.gpus, dim=0)
    else:
        model = mil_model

    # Resume from a previously saved model. An empty value (or None, which
    # used to crash on len()) means train from scratch.
    if opt.start_from:
        assert os.path.isdir(opt.start_from), " %s must be a path" % opt.start_from
        prefix = os.path.join(opt.start_from, os.path.basename(opt.start_from))
        lm_info_path = prefix + '.infos-best.pkl'
        lm_pth_path = prefix + '.model-best.pth'
        # The infos file is only required to exist here; it is not read.
        assert os.path.isfile(lm_info_path), "infos.pkl file does not exist in path %s" % opt.start_from
        model.load_state_dict(torch.load(lm_pth_path))
    model.cuda()
    model.train()  # Assure in training mode
    return model
resnet_mil.py 文件源码 项目:MIL.pytorch 作者: gujiuxiang 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def __init__(self, opt):
    """Build the ResNet-101 feature trunk and the MIL head.

    Stages conv1 through layer4 of a ``model.resnet.resnet101`` loaded from a
    local weights file form ``self.conv``. They are followed by a
    2048 -> 1000 linear layer (``self.l1``) and a 7x7 max-pool (``self.pool_mil``).
    """
    super(resnet_mil, self).__init__()
    import model.resnet as resnet_lib
    backbone = resnet_lib.resnet101()
    # NOTE(review): machine-specific absolute path to the pretrained weights.
    backbone.load_state_dict(torch.load('/media/jxgu/d2tb/model/resnet/resnet101.pth'))
    self.conv = torch.nn.Sequential()
    for stage in ("conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4"):
        self.conv.add_module(stage, getattr(backbone, stage))
    self.l1 = nn.Sequential(nn.Linear(2048, 1000),
                            nn.ReLU(True),
                            nn.Dropout(0.5))
    self.att_size = 7
    # Stride equal to the kernel size. Old PyTorch silently turned stride=0
    # into the kernel size (``stride or kernel_size``); current releases keep
    # the 0 and reject it at forward time.
    self.pool_mil = nn.MaxPool2d(kernel_size=self.att_size, stride=self.att_size)
graph.py 文件源码 项目:ParlAI 作者: facebookresearch 项目源码 文件源码 阅读 40 收藏 0 点赞 0 评论 0
def load_graph(self, fname):
    """Restore this object's private attributes from a saved graph file.

    Args:
        fname: base name resolved to ``<datapath>/graph_world2/<fname>.gw2``;
            an empty string reloads ``self._save_fname`` instead.

    Each non-callable attribute whose name starts with one underscore (but
    not two) is overwritten from the loaded mapping. Names missing from the
    file are reported and left untouched. If the file does not exist, a
    message is printed and nothing is loaded.
    """
    if fname == '':
        fname = self._save_fname
    else:
        graph_dir = os.path.join(self._opt['datapath'], 'graph_world2')
        fname = graph_dir + '/' + fname + '.gw2'

    if not os.path.isfile(fname):
        print("[graph file not found: " + fname + ']')
        return
    print("[loading graph: " + fname + ']')

    members = [
        name for name in dir(self)
        if not callable(getattr(self, name))
        and not name.startswith("__")
        and name.startswith("_")
    ]
    with open(fname, 'rb') as handle:
        saved = torch.load(handle)
    for name in members:
        if name not in saved:
            print("[ loading: " + name + " is missing in file ]")
            continue
        setattr(self, name, saved[name])
    self._save_fname = fname
utils.py 文件源码 项目:ParlAI 作者: facebookresearch 项目源码 文件源码 阅读 45 收藏 0 点赞 0 评论 0
def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=None):
    """Restore a training checkpoint into the given model/optimizer/scheduler.

    Args:
        filename: checkpoint path.
        cuda_device: when given, storages are restored onto ``cuda:<cuda_device>``.

    The optimizer state and ``lr_scheduler.best`` are restored only when the
    last optimizer-history entry names the same criterion class as ``criterion``.

    Returns:
        (extra_state, optimizer_history), or (None, []) if ``filename`` does not exist.
    """
    if not os.path.exists(filename):
        return None, []

    load_kwargs = {}
    if cuda_device is not None:
        device_name = 'cuda:{}'.format(cuda_device)
        load_kwargs['map_location'] = lambda storage, _loc: default_restore_location(storage, device_name)
    state = _upgrade_state_dict(torch.load(filename, **load_kwargs))

    # load model parameters
    model.load_state_dict(state['model'])

    optim_history = state['optimizer_history']
    last_optim = optim_history[-1]
    same_criterion = last_optim['criterion_name'] == criterion.__class__.__name__
    if same_criterion:
        optimizer.load_state_dict(state['last_optimizer_state'])
        lr_scheduler.best = last_optim['best_loss']

    return state['extra_state'], optim_history
ucf101.py 文件源码 项目:c3d_pytorch 作者: whitesnowdrop 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
    """Load the processed UCF101 split selected by ``train``.

    Args:
        root: dataset root; tensors are read from
            ``root/processed_folder/<training_file or test_file>``.
        train: when true, load ``training_file`` into ``train_data`` /
            ``train_labels``; otherwise load ``test_file`` into ``test_data`` /
            ``test_labels``.
        transform, target_transform: stored for later use.
        download: call ``self.download()`` first when true.
    """
    self.root = root
    self.transform = transform
    self.target_transform = target_transform
    self.train = train  # training set or test set

    if download:
        self.download()

    # NOTE(review): there is no dataset-existence check here, so a missing
    # file surfaces as an error from torch.load.

    split_file = self.training_file if self.train else self.test_file
    tensors = torch.load(os.path.join(root, self.processed_folder, split_file))
    if self.train:
        self.train_data, self.train_labels = tensors
    else:
        self.test_data, self.test_labels = tensors
solver.py 文件源码 项目:pytorch-tutorial 作者: yunjey 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def sample(self):
    """Load the final-epoch GAN weights and save a grid of generated images.

    Weights come from ``generator-<num_epochs>.pkl`` and
    ``discriminator-<num_epochs>.pkl`` under ``self.model_path``. The
    denormalised samples are written to ``<sample_path>/fake_samples-final.png``.
    """
    # Load trained parameters
    epochs = self.num_epochs
    g_path = os.path.join(self.model_path, 'generator-%d.pkl' % epochs)
    d_path = os.path.join(self.model_path, 'discriminator-%d.pkl' % epochs)
    for net, weights_path in ((self.generator, g_path), (self.discriminator, d_path)):
        net.load_state_dict(torch.load(weights_path))
    for net in (self.generator, self.discriminator):
        net.eval()

    # Sample the images
    z = self.to_variable(torch.randn(self.sample_size, self.z_dim))
    fake_images = self.generator(z)
    sample_path = os.path.join(self.sample_path, 'fake_samples-final.png')
    torchvision.utils.save_image(self.denorm(fake_images.data), sample_path, nrow=12)

    print("Saved sampled images to '%s'" % sample_path)
calc_plex.py 文件源码 项目:Tree-LSTM-LM 作者: vgene 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def main():
    """Report the perplexity of a saved model on a test file.

    Command-line options: ``--model`` (a whole model saved with torch.save),
    ``--test_file`` and ``--cuda``.
    """
    import sys
    if sys.version_info[0] < 3:
        # Python 2 only: make implicit str/unicode conversions use UTF-8.
        # ``reload`` as a builtin and ``sys.setdefaultencoding`` do not exist
        # on Python 3, where running this unconditionally raises NameError.
        reload(sys)
        sys.setdefaultencoding("utf-8")
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--model', type=str)
    argparser.add_argument('--test_file', type=str)
    argparser.add_argument('--cuda', action='store_true')
    args = argparser.parse_args()

    if args.cuda:
        model = torch.load(args.model)
    else:
        # Without --cuda, remap GPU-saved storages onto the CPU so the model
        # matches the CPU evaluation requested below.
        model = torch.load(args.model, map_location=lambda storage, loc: storage)
    print(model.vocab_size)
    batch_size = 1000
    tester = Tester(args.test_file, batch_size, model.mapping)
    perplexity = tester.calc_perplexity(model, cuda=args.cuda)
    print("Test File: {}, Perplexity:{}".format(args.test_file, perplexity))
main_hyperparams.py 文件源码 项目:cnn-lstm-bilstm-deepcnn-clstm-in-pytorch 作者: bamtercelboo 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def mrs_two_mui(path, train_name, dev_name, test_name, char_data, text_field, label_field, static_text_field, static_label_field, **kargs):
    """Load the MR splits (via ``mydatasets_self_two``) twice and build iterators.

    The same files are parsed once with ``text_field``/``label_field`` and once
    with the ``static_*`` fields. The regular text vocabulary is built from the
    training split only; the static one is built from all three splits. Both
    text vocabularies use ``args.min_freq``, which relies on a module-level
    ``args``. ``**kargs`` is forwarded to ``data.Iterator.splits``. Dev and
    test are each served as a single full-size batch.

    Returns:
        (train_iter, dev_iter, test_iter) over the regular (non-static) data.
    """
    train_data, dev_data, test_data = mydatasets_self_two.MR.splits(
        path, train_name, dev_name, test_name, char_data, text_field, label_field)
    static_train_data, static_dev_data, static_test_data = mydatasets_self_two.MR.splits(
        path, train_name, dev_name, test_name, char_data, static_text_field, static_label_field)
    print("len(train_data) {} ".format(len(train_data)))
    print("len(static_train_data) {} ".format(len(static_train_data)))
    text_field.build_vocab(train_data, min_freq=args.min_freq)
    label_field.build_vocab(train_data)
    static_text_field.build_vocab(static_train_data, static_dev_data, static_test_data, min_freq=args.min_freq)
    static_label_field.build_vocab(static_train_data, static_dev_data, static_test_data)
    train_iter, dev_iter, test_iter = data.Iterator.splits(
        (train_data, dev_data, test_data),
        batch_sizes=(args.batch_size, len(dev_data), len(test_data)),
        **kargs)
    return train_iter, dev_iter, test_iter



# load five-classification data
main_hyperparams.py 文件源码 项目:cnn-lstm-bilstm-deepcnn-clstm-in-pytorch 作者: bamtercelboo 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def mrs_five_mui(path, train_name, dev_name, test_name, char_data, text_field, label_field, static_text_field, static_label_field, **kargs):
    """Load the MR splits (via ``mydatasets_self_five``) twice and build iterators.

    The same files are parsed once with ``text_field``/``label_field`` and once
    with the ``static_*`` fields. The regular text vocabulary is built from the
    training split only; the static one is built from all three splits. Both
    text vocabularies use ``args.min_freq``, which relies on a module-level
    ``args``. ``**kargs`` is forwarded to ``data.Iterator.splits``. Dev and
    test are each served as a single full-size batch.

    Returns:
        (train_iter, dev_iter, test_iter) over the regular (non-static) data.
    """
    train_data, dev_data, test_data = mydatasets_self_five.MR.splits(
        path, train_name, dev_name, test_name, char_data, text_field, label_field)
    static_train_data, static_dev_data, static_test_data = mydatasets_self_five.MR.splits(
        path, train_name, dev_name, test_name, char_data, static_text_field, static_label_field)
    print("len(train_data) {} ".format(len(train_data)))
    print("len(static_train_data) {} ".format(len(static_train_data)))
    text_field.build_vocab(train_data, min_freq=args.min_freq)
    label_field.build_vocab(train_data)
    static_text_field.build_vocab(static_train_data, static_dev_data, static_test_data, min_freq=args.min_freq)
    static_label_field.build_vocab(static_train_data, static_dev_data, static_test_data)
    train_iter, dev_iter, test_iter = data.Iterator.splits(
        (train_data, dev_data, test_data),
        batch_sizes=(args.batch_size, len(dev_data), len(test_data)),
        **kargs)
    return train_iter, dev_iter, test_iter


# load MR dataset
utils.py 文件源码 项目:pytorch-arda 作者: corenel 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def init_model(net, restore):
    """Initialise ``net``, optionally restore saved weights, and use CUDA when available.

    Args:
        net: module to initialise with ``init_weights``.
        restore: path of a saved state_dict, or None. Weights are loaded (and
            ``net.restored`` set to True) only when the path exists.

    Returns:
        The same ``net`` object.
    """
    # init weights of model
    net.apply(init_weights)

    # restore model weights
    should_restore = restore is not None and os.path.exists(restore)
    if should_restore:
        net.load_state_dict(torch.load(restore))
        net.restored = True
        print("Restore model from: {}".format(os.path.abspath(restore)))

    # check if cuda is available
    if not torch.cuda.is_available():
        return net
    cudnn.benchmark = True
    net.cuda()
    return net
pytorch_model.py 文件源码 项目:char-rnn-text-generation 作者: yxtay 项目源码 文件源码 阅读 43 收藏 0 点赞 0 评论 0
def generate_main(args):
    """Generate text from the model checkpoint at ``args.checkpoint_path``.

    Main method for the ``generate`` subcommand. ``args.seed`` is used as the
    seed text when set; otherwise a seed is derived with ``generate_seed``
    from the contents of ``args.text_path``.

    Returns:
        The result of ``generate_text`` for ``args.length`` and ``args.top_n``.
    """
    # load model
    inference_model = Model.load(args.checkpoint_path)

    # create seed if not specified
    seed = args.seed
    if seed is None:
        with open(args.text_path) as f:
            text = f.read()
        seed = generate_seed(text)
        logger.info("seed sequence generated from %s.", args.text_path)

    return generate_text(inference_model, seed, args.length, args.top_n)
embeddings.py 文件源码 项目:treehopper 作者: tomekkorbak 项目源码 文件源码 阅读 33 收藏 0 点赞 0 评论 0
def load_word_vectors(embeddings_path):
    """Load word vectors and their vocabulary for ``embeddings_path``.

    Sources are tried in this order:
      * ``<path>.pth`` together with ``<path>.vocab``: a cached tensor and
        vocabulary, returned directly;
      * ``<path>.vec``: read with ``FastText.load_word2vec_format``;
      * ``<path>.model``: read with ``KeyedVectors.load``.
    When a gensim model is read, ``<path>.vocab`` is (re)written with one
    token per line, and row ``i`` of the returned tensor is the vector of the
    ``i``-th token.

    Returns:
        (vocab, vectors)

    Raises:
        FileNotFoundError: if none of the sources above exist.
    """
    if os.path.isfile(embeddings_path + '.pth') and \
            os.path.isfile(embeddings_path + '.vocab'):
        print('==> File found, loading to memory')
        vectors = torch.load(embeddings_path + '.pth')
        vocab = Vocab(filename=embeddings_path + '.vocab')
        return vocab, vectors

    # If both files exist the .vec file is the one used; checking it first
    # avoids loading the .model file only to discard it.
    if os.path.isfile(embeddings_path + '.vec'):
        model = FastText.load_word2vec_format(embeddings_path + '.vec')
    elif os.path.isfile(embeddings_path + '.model'):
        model = KeyedVectors.load(embeddings_path + ".model")
    else:
        # Previously this fell through to an UnboundLocalError on ``model``.
        raise FileNotFoundError(
            'No embeddings found for {!r}: expected a .pth/.vocab pair, '
            'a .vec file or a .model file'.format(embeddings_path))

    list_of_tokens = list(model.vocab.keys())
    vectors = torch.zeros(len(list_of_tokens), model.vector_size)
    with open(embeddings_path + '.vocab', 'w', encoding='utf-8') as f:
        for token in list_of_tokens:
            f.write(token + '\n')
    vocab = Vocab(filename=embeddings_path + '.vocab')
    for index, word in enumerate(list_of_tokens):
        vectors[index, :] = torch.from_numpy(model[word])
    return vocab, vectors
utils.py 文件源码 项目:fairseq-py 作者: facebookresearch 项目源码 文件源码 阅读 31 收藏 0 点赞 0 评论 0
def load_model_state(filename, model, cuda_device=None):
    """Load model parameters from a checkpoint file.

    Args:
        filename: checkpoint path.
        model: module whose ``upgrade_state_dict`` and ``load_state_dict`` are used.
        cuda_device: when given, storages are restored onto ``cuda:<cuda_device>``.

    Returns:
        (extra_state, optimizer_history, last_optimizer_state), or
        (None, [], None) when ``filename`` does not exist.

    Raises:
        Exception: when the parameters cannot be loaded into ``model``; the
            underlying error is chained as ``__cause__``.
    """
    if not os.path.exists(filename):
        return None, [], None
    if cuda_device is None:
        state = torch.load(filename)
    else:
        state = torch.load(
            filename,
            map_location=lambda s, l: default_restore_location(s, 'cuda:{}'.format(cuda_device))
        )
    state = _upgrade_state_dict(state)
    state['model'] = model.upgrade_state_dict(state['model'])

    # load model parameters
    try:
        model.load_state_dict(state['model'])
    except Exception as err:
        # Not a bare ``except``: KeyboardInterrupt/SystemExit pass through,
        # and chaining keeps the actual key/shape mismatch visible.
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match') from err

    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']
utils.py 文件源码 项目:3DGAN-Pytorch 作者: rimchang 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def read_pickle(path, G, G_solver, D_, D_solver):
    """Best-effort restore of the newest GAN checkpoint found in ``path``.

    File names are expected to end in ``_<iteration>.<ext>``. The largest
    iteration number is taken, and ``G_<it>.pkl``, ``G_optim_<it>.pkl``,
    ``D_<it>.pkl`` and ``D_optim_<it>.pkl`` are loaded into ``G``,
    ``G_solver``, ``D_`` and ``D_solver`` respectively. Any failure (missing
    directory, no checkpoints, bad file name) is printed and otherwise
    ignored, so training can start fresh.
    """
    try:
        iterations = sorted(int(name.split('_')[-1].split('.')[0]) for name in os.listdir(path))
        recent_iter = str(iterations[-1])
        print(recent_iter, path)

        targets = (("G", G), ("G_optim", G_solver), ("D", D_), ("D_optim", D_solver))
        for prefix, target in targets:
            with open(path + "/" + prefix + "_" + recent_iter + ".pkl", "rb") as handle:
                target.load_state_dict(torch.load(handle))
    except Exception as err:
        print("fail try read_pickle", err)
motion_cnn.py 文件源码 项目:two-stream-action-recognition 作者: jeffreyhuang1 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def resume_and_evaluate(self):
    """Optionally resume from ``self.resume`` and, in evaluate mode, validate once.

    When ``self.resume`` names an existing file, it must provide ``epoch``,
    ``best_prec1``, ``state_dict`` and ``optimizer``. From these, this sets
    ``start_epoch`` and ``best_prec1`` and restores the model and optimizer
    states. A missing file is only reported. When ``self.evaluate`` is set,
    ``self.epoch`` is reset to 0 and ``validate_1epoch`` runs once; its
    ``(prec1, val_loss)`` result is not used. Always returns None.
    """
    if self.resume:
        if not os.path.isfile(self.resume):
            print("==> no checkpoint found at '{}'".format(self.resume))
        else:
            print("==> loading checkpoint '{}'".format(self.resume))
            checkpoint = torch.load(self.resume)
            self.start_epoch = checkpoint['epoch']
            self.best_prec1 = checkpoint['best_prec1']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            print("==> loaded checkpoint '{}' (epoch {}) (best_prec1 {})"
                  .format(self.resume, checkpoint['epoch'], self.best_prec1))
    if self.evaluate:
        self.epoch = 0
        _prec1, _val_loss = self.validate_1epoch()
        return


问题


面经


文章

微信
公众号

扫码关注公众号