test_inferno.py source code

python

Project: inferno    Author: inferno-pytorch
import torch
from torch.utils.data import DataLoader, TensorDataset


def setUpDatasets(self):
    # Generate random data: NUM_SAMPLES inputs of shape (3, 32, 32)
    inputs, targets = self.generate_random_data(self.NUM_SAMPLES, (3, 32, 32),
                                                num_classes=self.NUM_CLASSES,
                                                dtype='float32')
    # Split into training and validation sets: the first NUM_TRAINING_SAMPLES
    # samples are used for training, the remainder for validation
    train_inputs = inputs[:self.NUM_TRAINING_SAMPLES]
    train_targets = targets[:self.NUM_TRAINING_SAMPLES]
    validate_inputs = inputs[self.NUM_TRAINING_SAMPLES:]
    validate_targets = targets[self.NUM_TRAINING_SAMPLES:]
    # Convert the NumPy arrays to tensors and wrap them in datasets
    train_dataset = TensorDataset(torch.from_numpy(train_inputs),
                                  torch.from_numpy(train_targets))
    validate_dataset = TensorDataset(torch.from_numpy(validate_inputs),
                                     torch.from_numpy(validate_targets))
    # Build dataloaders: batches of 16, shuffled, loaded by two worker processes
    self.train_loader = DataLoader(train_dataset, batch_size=16,
                                   shuffle=True, num_workers=2, pin_memory=False)
    self.validate_loader = DataLoader(validate_dataset, batch_size=16,
                                      shuffle=True, num_workers=2, pin_memory=False)
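To run the same pipeline outside the test class, here is a minimal standalone sketch. `generate_random_data` and the constants `NUM_SAMPLES`, `NUM_TRAINING_SAMPLES` and `NUM_CLASSES` are hypothetical stand-ins: the real helper and values are defined elsewhere in inferno's test class and are not shown here. The sketch assumes the helper returns float32 inputs and integer class labels. It also uses `num_workers=0` and an unshuffled validation loader, both of which differ from the original.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical values standing in for the test-class constants (assumptions).
NUM_SAMPLES = 100
NUM_TRAINING_SAMPLES = 70
NUM_CLASSES = 10


def generate_random_data(num_samples, shape, num_classes, dtype="float32"):
    """Hypothetical helper: random inputs plus integer class labels."""
    rng = np.random.default_rng(0)
    inputs = rng.random((num_samples, *shape)).astype(dtype)
    targets = rng.integers(0, num_classes, size=num_samples).astype("int64")
    return inputs, targets


def build_loaders(batch_size=16):
    inputs, targets = generate_random_data(NUM_SAMPLES, (3, 32, 32),
                                           num_classes=NUM_CLASSES)
    # The first NUM_TRAINING_SAMPLES samples train, the rest validate.
    split = NUM_TRAINING_SAMPLES
    train_dataset = TensorDataset(torch.from_numpy(inputs[:split]),
                                  torch.from_numpy(targets[:split]))
    validate_dataset = TensorDataset(torch.from_numpy(inputs[split:]),
                                     torch.from_numpy(targets[split:]))
    # num_workers=0 loads batches in the main process (no worker subprocesses).
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=0)
    validate_loader = DataLoader(validate_dataset, batch_size=batch_size,
                                 shuffle=False, num_workers=0)
    return train_loader, validate_loader


train_loader, validate_loader = build_loaders()
x, y = next(iter(train_loader))  # one training batch of inputs and labels
print(x.shape, x.dtype, y.shape)
```

Loading in the main process keeps the demo free of worker subprocesses. Shuffling the validation set, as the original does, does not change aggregate metrics, so the sketch leaves it in a fixed order.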