Python Dataset() class: example source code

data.py (project: vsepp, author: fartashf)
def get_test_loader(split_name, data_name, vocab, crop_size, batch_size,
                    workers, opt):
    dpath = os.path.join(opt.data_path, data_name)
    if opt.data_name.endswith('_precomp'):
        test_loader = get_precomp_loader(dpath, split_name, vocab, opt,
                                         batch_size, False, workers)
    else:
        # Build Dataset Loader
        roots, ids = get_paths(dpath, data_name, opt.use_restval)

        transform = get_transform(data_name, split_name, opt)
        test_loader = get_loader_single(opt.data_name, split_name,
                                        roots[split_name]['img'],
                                        roots[split_name]['cap'],
                                        vocab, transform, ids=ids[split_name],
                                        batch_size=batch_size, shuffle=False,
                                        num_workers=workers,
                                        collate_fn=collate_fn)

    return test_loader
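Every snippet on this page implements or consumes the same two-method contract: a torch.utils.data.Dataset subclass only needs __getitem__ and __len__, and DataLoader takes care of batching, shuffling, and worker processes. A minimal self-contained sketch of that contract (not taken from any of the projects below):

import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):
    # Toy dataset: the i-th sample is the pair (i, i**2).
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return torch.tensor(idx), torch.tensor(idx ** 2)

loader = DataLoader(SquaresDataset(), batch_size=8, shuffle=True)
for x, y in loader:
    pass  # x and y are batched tensors of shape (8,)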
data_loader.py (project: make_dataset, author: hyzhan)
def __init__(self, audio_conf, manifest_filepath, labels, normalize=False, augment=False):
        """
        Dataset that loads tensors from a CSV containing file paths to audio files and transcripts separated by
        a comma. Each new line is a different sample. Example below:

        /path/to/audio.wav,/path/to/audio.txt
        ...

        :param audio_conf: Dictionary containing the sample rate, window type, and window length/stride in seconds
        :param manifest_filepath: Path to the manifest CSV described above
        :param labels: String containing all the possible characters to map to
        :param normalize: Apply mean and standard deviation normalization to the audio tensor
        :param augment: Apply random tempo and gain perturbations (default: False)
        """
        with open(manifest_filepath) as f:
            ids = f.readlines()
        ids = [x.strip().split(',') for x in ids]
        self.ids = ids
        self.size = len(ids)
        self.labels_map = {label: i for i, label in enumerate(labels)}
        super(SpectrogramDataset, self).__init__(audio_conf, normalize, augment)
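For a sense of how labels_map gets used downstream, a hypothetical parse_transcript helper (the name and the example label string are illustrative, not from the listing):

labels_map = {label: i for i, label in enumerate("_'abcdefghijklmnopqrstuvwxyz ")}

def parse_transcript(transcript, labels_map):
    # Map each known character to its integer index; drop unknown characters.
    return [labels_map[ch] for ch in transcript if ch in labels_map]

print(parse_transcript("hello world", labels_map))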
vctk.py (project: audio, author: pytorch)
def __init__(self, root, downsample=True, transform=None, target_transform=None, download=False, dev_mode=False):
        self.root = os.path.expanduser(root)
        self.downsample = downsample
        self.transform = transform
        self.target_transform = target_transform
        self.dev_mode = dev_mode
        self.data = []
        self.labels = []
        self.chunk_size = 1000
        self.num_samples = 0
        self.max_len = 0
        self.cached_pt = 0

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')
        self._read_info()
        self.data, self.labels = torch.load(os.path.join(
            self.root, self.processed_folder, "vctk_{:04d}.pt".format(self.cached_pt)))
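Only __init__ is shown here. Given the chunk_size and cached_pt fields, a plausible __getitem__ for this sharded layout reloads the cached .pt shard when the index crosses a chunk boundary (a sketch, not the project's verbatim code):

    def __getitem__(self, index):
        # Each shard holds chunk_size samples; reload only on a shard change.
        shard = index // self.chunk_size
        if shard != self.cached_pt:
            self.cached_pt = shard
            self.data, self.labels = torch.load(os.path.join(
                self.root, self.processed_folder,
                "vctk_{:04d}.pt".format(self.cached_pt)))
        audio = self.data[index % self.chunk_size]
        target = self.labels[index % self.chunk_size]
        if self.transform is not None:
            audio = self.transform(audio)
        return audio, target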
yesno.py (project: audio, author: pytorch)
def __init__(self, root, transform=None, target_transform=None, download=False, dev_mode=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.dev_mode = dev_mode
        self.data = []
        self.labels = []
        self.num_samples = 0
        self.max_len = 0

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')
        self.data, self.labels = torch.load(os.path.join(
            self.root, self.processed_folder, self.processed_file))
ucf101.py (project: c3d_pytorch, author: whitesnowdrop)
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train                              # training set or test set

        if download:
            self.download()

        #if not self._check_exists():
        #    raise RuntimeError('Dataset not found.' + ' You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(root, self.processed_folder, self.test_file))
data_loader.py (project: seq2seq-dataloader, author: yunjey)
def get_loader(src_path, trg_path, src_word2id, trg_word2id, batch_size=100):
    """Returns data loader for custom dataset.

    Args:
        src_path: txt file path for source domain.
        trg_path: txt file path for target domain.
        src_word2id: word-to-id dictionary (source domain).
        trg_word2id: word-to-id dictionary (target domain).
        batch_size: mini-batch size.

    Returns:
        data_loader: data loader for custom dataset.
    """
    # build a custom dataset
    dataset = Dataset(src_path, trg_path, src_word2id, trg_word2id)

    # Data loader for the custom dataset.
    # Each iteration returns (src_seqs, src_lengths, trg_seqs, trg_lengths);
    # see collate_fn for details (a padding sketch follows this snippet).
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              collate_fn=collate_fn)

    return data_loader
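The collate_fn referenced above lives elsewhere in the repo. For variable-length (src, trg) pairs it typically sorts by source length and zero-pads both sides, along these lines (a sketch assuming each sequence is a 1-D LongTensor):

import torch

def collate_fn(data):
    # Sort pairs by source length, longest first, then pad each side.
    data.sort(key=lambda pair: len(pair[0]), reverse=True)
    src_seqs, trg_seqs = zip(*data)

    def merge(seqs):
        lengths = [len(s) for s in seqs]
        padded = torch.zeros(len(seqs), max(lengths)).long()
        for i, s in enumerate(seqs):
            padded[i, :lengths[i]] = s
        return padded, lengths

    src_seqs, src_lengths = merge(src_seqs)
    trg_seqs, trg_lengths = merge(trg_seqs)
    return src_seqs, src_lengths, trg_seqs, trg_lengths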
motion_dataloader.py (project: two-stream-action-recognition, author: jeffreyhuang1)
def __getitem__(self, idx):
        #print ('mode:',self.mode,'calling Dataset:__getitem__ @ idx=%d'%idx)

        if self.mode == 'train':
            self.video, nb_clips = self.keys[idx].split('-')
            self.clips_idx = random.randint(1,int(nb_clips))
        elif self.mode == 'val':
            self.video,self.clips_idx = self.keys[idx].split('-')
        else:
            raise ValueError('There are only train and val modes')

        label = self.values[idx]
        label = int(label)-1 
        data = self.stackopf()

        if self.mode == 'train':
            sample = (data,label)
        elif self.mode == 'val':
            sample = (self.video,data,label)
        else:
            raise ValueError('There are only train and val modes')
        return sample
cub200.py (project: generative_zoo, author: DL-IT)
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file))
test_datasets.py (project: nnmnkwii, author: r9y9)
def test_sequence_wise_torch_data_loader():
    import torch
    from torch.utils import data as data_utils

    X, Y = _get_small_datasets(padded=False)

    class TorchDataset(data_utils.Dataset):
        def __init__(self, X, Y):
            self.X = X
            self.Y = Y

        def __getitem__(self, idx):
            return torch.from_numpy(self.X[idx]), torch.from_numpy(self.Y[idx])

        def __len__(self):
            return len(self.X)

    def __test(X, Y, batch_size):
        dataset = TorchDataset(X, Y)
        loader = data_utils.DataLoader(
            dataset, batch_size=batch_size, num_workers=1, shuffle=True)
        for idx, (x, y) in enumerate(loader):
            assert len(x.shape) == len(y.shape)
            assert len(x.shape) == 3
            print(idx, x.shape, y.shape)

    # Test with batch_size = 1
    yield __test, X, Y, 1
    # Since we have variable length frames, batch size larger than 1 causes
    # runtime error.
    yield raises(RuntimeError)(__test), X, Y, 2

    # For a padded dataset, which can be represented as (N, T^max, D), the
    # batch size can be any number.
    X, Y = _get_small_datasets(padded=True)
    yield __test, X, Y, 1
    yield __test, X, Y, 2
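One common way around the batch_size > 1 restriction, short of pre-padding the whole dataset, is a padding collate_fn; a minimal sketch using torch.nn.utils.rnn.pad_sequence (reusing TorchDataset from above):

from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    # Zero-pad every sequence to the longest in the batch, giving
    # (B, T_max, D) tensors, so batch sizes larger than 1 collate cleanly.
    xs, ys = zip(*batch)
    return pad_sequence(xs, batch_first=True), pad_sequence(ys, batch_first=True)

loader = data_utils.DataLoader(
    TorchDataset(X, Y), batch_size=2, collate_fn=pad_collate)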
test_datasets.py (project: nnmnkwii, author: r9y9)
def test_frame_wise_torch_data_loader():
    import torch
    from torch.utils import data as data_utils

    X, Y = _get_small_datasets(padded=False)

    # Since torch's Dataset (and Chainer's, and maybe others) assumes the
    # dataset has a fixed length, i.e., implements the `__len__` method, we
    # need to know the number of frames for each utterance.
    # The sum of the frame counts is the dataset size for frame-wise iteration.
    lengths = np.array([len(x) for x in X], dtype=int)

    # For the above reason, we need to explicitly give the number of frames.
    X = MemoryCacheFramewiseDataset(X, lengths, cache_size=len(X))
    Y = MemoryCacheFramewiseDataset(Y, lengths, cache_size=len(Y))

    class TorchDataset(data_utils.Dataset):
        def __init__(self, X, Y):
            self.X = X
            self.Y = Y

        def __getitem__(self, idx):
            return torch.from_numpy(self.X[idx]), torch.from_numpy(self.Y[idx])

        def __len__(self):
            return len(self.X)

    def __test(X, Y, batch_size):
        dataset = TorchDataset(X, Y)
        loader = data_utils.DataLoader(
            dataset, batch_size=batch_size, num_workers=1, shuffle=True)
        for idx, (x, y) in enumerate(loader):
            assert len(x.shape) == 2
            assert len(y.shape) == 2

    yield __test, X, Y, 128
    yield __test, X, Y, 256
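The frame-wise wrapper hides the index arithmetic, but the core idea is mapping a global frame index back to (utterance, frame) through cumulative lengths; a standalone sketch of that idea (not nnmnkwii's implementation):

import numpy as np

def locate_frame(frame_idx, lengths):
    # lengths[i] is the number of frames in utterance i.
    offsets = np.cumsum(lengths)
    utt = int(np.searchsorted(offsets, frame_idx, side='right'))
    frame = frame_idx - (offsets[utt - 1] if utt > 0 else 0)
    return utt, frame

print(locate_frame(7, [3, 5, 4]))  # (1, 4): frame 4 of utterance 1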
usps.py (project: pytorch-arda, author: corenel)
def __init__(self, root, train=True, transform=None, download=False):
        """Init USPS dataset."""
        # init params
        self.root = os.path.expanduser(root)
        self.filename = "usps_28x28.pkl"
        self.train = train
        # Num of train samples = 7438, num of test samples = 1860
        self.transform = transform
        self.dataset_size = None

        # download dataset.
        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError("Dataset not found." +
                               " You can use download=True to download it")

        self.train_data, self.train_labels = self.load_samples()
        if self.train:
            total_num_samples = self.train_labels.shape[0]
            indices = np.arange(total_num_samples)
            np.random.shuffle(indices)
            self.train_data = self.train_data[indices[0:self.dataset_size], ::]
            self.train_labels = self.train_labels[indices[0:self.dataset_size]]
        self.train_data *= 255.0
        self.train_data = self.train_data.transpose(
            (0, 2, 3, 1))  # convert to HWC
data_loader.py (project: sketchnet, author: jtoy)
def __init__(self, parent_ds, offset, length):
        self.parent_ds = parent_ds
        self.offset = offset
        self.length = length
        assert len(parent_ds) >= offset + length, "Parent Dataset not long enough"
        super(PartialDataset, self).__init__()
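The listing stops at __init__; the indexing methods that make PartialDataset usable would plausibly be (a sketch, not the project's verbatim code):

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        # Shift the index into the parent dataset's range.
        return self.parent_ds[self.offset + i]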
test_dataloader.py (project: pytorch, author: tylergenter)
def test_numpy(self):
        import numpy as np

        class TestDataset(torch.utils.data.Dataset):
            def __getitem__(self, i):
                return np.ones((2, 3, 4)) * i

            def __len__(self):
                return 1000

        loader = DataLoader(TestDataset(), batch_size=12)
        batch = next(iter(loader))
        self.assertIsInstance(batch, torch.DoubleTensor)
        self.assertEqual(batch.size(), torch.Size([12, 2, 3, 4]))
test_dataloader.py (project: pytorch, author: tylergenter)
def test_numpy_scalars(self):
        import numpy as np

        class ScalarDataset(torch.utils.data.Dataset):
            def __init__(self, dtype):
                self.dtype = dtype

            def __getitem__(self, i):
                return self.dtype()

            def __len__(self):
                return 4

        dtypes = {
            np.float64: torch.DoubleTensor,
            np.float32: torch.FloatTensor,
            np.float16: torch.HalfTensor,
            np.int64: torch.LongTensor,
            np.int32: torch.IntTensor,
            np.int16: torch.ShortTensor,
            np.int8: torch.CharTensor,
            np.uint8: torch.ByteTensor,
        }
        for dt, tt in dtypes.items():
            dset = ScalarDataset(dt)
            loader = DataLoader(dset, batch_size=2)
            batch = next(iter(loader))
            self.assertIsInstance(batch, tt)
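These two tests pin down the NumPy-to-torch dtype mapping applied by DataLoader's default collate. A quick standalone check of the same behavior (a plain list works as a map-style dataset since it has __len__ and __getitem__; the exact tensor type printed depends on the torch version):

import numpy as np
from torch.utils.data import DataLoader

batch = next(iter(DataLoader([np.float32(0.5)] * 4, batch_size=2)))
print(type(batch), batch.size())  # a float32 tensor of shape (2,)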
dataset.1.py (project: PyTorchText, author: chenyuntc)
def __init__(self,train_root,labels_file,type_='char'):
        '''
        Dataset('/mnt/7/zhihu/ieee_zhihu_cup/train.npz','/mnt/7/zhihu/ieee_zhihu_cup/a.json')
        '''
        import json
        with open(labels_file) as f:
            labels_ = json.load(f)

        # embedding_d = np.load(embedding_root)['vector']
        question_d = np.load(train_root)

        all_char_title, all_char_content = question_d['title_char'], question_d['content_char']
        all_word_title, all_word_content = question_d['title_word'], question_d['content_word']

        self.train_data = (all_char_title[:-20000], all_char_content[:-20000]), (all_word_title[:-20000], all_word_content[:-20000])
        self.val_data = (all_char_title[-20000:], all_char_content[-20000:]), (all_word_title[-20000:], all_word_content[-20000:])
        self.all_num = len(all_char_title)
        # del all_data_title,all_data_content

        self.data_title,self.data_content = self.train_data
        self.len_ = len(self.data_title[0])

        self.index2qid = question_d['index2qid'].item()
        self.l_end=0
        self.labels = labels_['d']
dataset.py (project: PyTorchText, author: chenyuntc)
def __init__(self,train_root,labels_file,type_='char',augument=True):
        '''
        Dataset('/mnt/7/zhihu/ieee_zhihu_cup/train.npz','/mnt/7/zhihu/ieee_zhihu_cup/a.json')
        '''
        import json
        with open(labels_file) as f:
            labels_ = json.load(f)

        # embedding_d = np.load(embedding_root)['vector']
        self.augument=augument
        question_d = np.load(train_root)
        self.type_=type_
        if type_ == 'char':
            all_data_title,all_data_content =\
                 question_d['title_char'],question_d['content_char']

        elif type_ == 'word':
            all_data_title,all_data_content =\
                 question_d['title_word'],question_d['content_word']

        self.train_data = all_data_title[:-200000],all_data_content[:-200000]
        self.val_data = all_data_title[-200000:],all_data_content[-200000:]

        self.all_num = len(all_data_content)
        # del all_data_title,all_data_content

        self.data_title,self.data_content = self.train_data
        self.len_ = len(self.data_title)

        self.index2qid = question_d['index2qid'].item()
        self.l_end=0
        self.labels = labels_['d']

        self.training=True
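The augument flag is consumed elsewhere in the class. Text pipelines like this commonly implement it as random token dropout or shuffling, for example (hypothetical; the project's own augmentation may differ):

import numpy as np

def augment_seq(seq, p_drop=0.25):
    # Randomly either shuffle the tokens or drop a fraction of them.
    if np.random.rand() < 0.5:
        return np.random.permutation(seq)
    keep = np.random.rand(len(seq)) > p_drop
    return seq[keep]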
dataset.py (project: PyTorchText, author: chenyuntc)
def __init__(self,train_root,labels_file,type_='char',fold=0):
        '''
        Dataset('/mnt/7/zhihu/ieee_zhihu_cup/train.npz','/mnt/7/zhihu/ieee_zhihu_cup/a.json')
        '''
        import json
        with open(labels_file) as f:
            labels_ = json.load(f)
        self.fold=fold 
        # embedding_d = np.load(embedding_root)['vector']
        question_d = np.load(train_root)

        all_char_title, all_char_content = question_d['title_char'], question_d['content_char']
        all_word_title, all_word_content = question_d['title_word'], question_d['content_word']

        self.train_data = (all_char_title[:-200000], all_char_content[:-200000]), (all_word_title[:-200000], all_word_content[:-200000])
        self.val_data = (all_char_title[-200000:], all_char_content[-200000:]), (all_word_title[-200000:], all_word_content[-200000:])
        self.all_num = len(all_char_title)
        # del all_data_title,all_data_content

        self.data_title,self.data_content = self.train_data
        self.len_ = len(self.data_title[0])
        self.training=True
        self.index2qid = question_d['index2qid'].item()
        self.l_end=0
        self.labels = labels_['d']
fold_dataset.py (project: PyTorchText, author: chenyuntc)
def __init__(self,train_root,labels_file,type_='char',fold=0):
        '''
        Dataset('/mnt/7/zhihu/ieee_zhihu_cup/train.npz','/mnt/7/zhihu/ieee_zhihu_cup/a.json')
        '''
        import json
        with open(labels_file) as f:
            labels_ = json.load(f)
        self.fold =fold

        # embedding_d = np.load(embedding_root)['vector']
        question_d = np.load(train_root)
        self.type_=type_
        if type_ == 'char':
            all_data_title,all_data_content =\
                 question_d['title_char'],question_d['content_char']

        elif type_ == 'word':
            all_data_title,all_data_content =\
                 question_d['title_word'],question_d['content_word']

        self.train_data = all_data_title[:-200000],all_data_content[:-200000]
        self.val_data = all_data_title[-200000:],all_data_content[-200000:]

        self.all_num = len(all_data_content)
        # del all_data_title,all_data_content

        self.data_title,self.data_content = self.train_data
        self.len_ = len(self.data_title)

        self.training=True

        self.index2qid = question_d['index2qid'].item()
        self.l_end=0
        self.labels = labels_['d']
data.py (project: sk-torch, author: mattHawthorn)
def __init__(self, dataset: Dataset, X_encoder: Opt[Callable[[T1], TensorType]]=None,
                 y_encoder: Opt[Callable[[T2], TensorType]]=None, **kwargs):
        super().__init__(dataset=dataset, **kwargs)
        self.X_encoder = X_encoder
        self.y_encoder = y_encoder
omniglot.py (project: MatchingNetworks, author: gitabcworld)
def __init__(self, root, transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.'
                               + ' You can use download=True to download it')

        self.all_items=find_classes(os.path.join(self.root, self.processed_folder))
        self.idx_classes=index_classes(self.all_items)
miniImagenetOneShot.py (project: MatchingNetworks, author: gitabcworld)
def __init__(self, dataroot = '/home/aberenguel/Dataset/miniImagenet', type = 'train',
                 nEpisodes = 1000, classes_per_set=10, samples_per_class=1):

        self.nEpisodes = nEpisodes
        self.classes_per_set = classes_per_set
        self.samples_per_class = samples_per_class
        self.n_samples = self.samples_per_class * self.classes_per_set
        self.n_samplesNShot = 5 # Samples per meta-test. In this case 1 as is OneShot.
        # Transformations to the image
        self.transform = transforms.Compose([filenameToPILImage,
                                             PiLImageResize,
                                             transforms.ToTensor()
                                             ])

        def loadSplit(splitFile):
            dictLabels = {}
            with open(splitFile) as csvfile:
                csvreader = csv.reader(csvfile, delimiter=',')
                next(csvreader, None)
                for i,row in enumerate(csvreader):
                    filename = row[0]
                    label = row[1]
                    if label in dictLabels.keys():
                        dictLabels[label].append(filename)
                    else:
                        dictLabels[label] = [filename]
            return dictLabels

        #requiredFiles = ['train','val','test']
        self.miniImagenetImagesDir = os.path.join(dataroot,'images')
        self.data = loadSplit(splitFile = os.path.join(dataroot,type + '.csv'))
        self.data = collections.OrderedDict(sorted(self.data.items()))
        self.classes_dict = {label: i for i, label in enumerate(self.data.keys())}  # dict.keys() is not indexable in Python 3
        self.create_episodes(self.nEpisodes)
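loadSplit expects a split CSV with a header row followed by filename,label pairs, i.e. something like (illustrative values):

filename,label
im_000001.jpg,n01532829
im_000002.jpg,n01532829
im_000003.jpg,n01558993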

