def __init__(self, root, train=True, transform=None, download=False,
             dataset_size=None):
    """Init USPS dataset.

    Args:
        root: Dataset directory; ``~`` is expanded with
            ``os.path.expanduser``.
        train: If True, the loaded samples are shuffled with NumPy's global
            RNG and then truncated to at most ``dataset_size`` samples.
        transform: Stored on ``self.transform``; it is not applied here
            (presumably used per-sample elsewhere in the class).
        download: If True, ``self.download()`` is called before checking
            that the data exists.
        dataset_size: Maximum number of training samples to keep after
            shuffling. ``None`` (the default) keeps every sample. Ignored
            when ``train`` is False.

    Raises:
        RuntimeError: If ``self._check_exists()`` reports the data missing.
    """
    # init params
    self.root = os.path.expanduser(root)
    self.filename = "usps_28x28.pkl"
    self.train = train
    # Num of Train = 7438, Num of Test = 1860 (per the original author's note).
    self.transform = transform
    self.dataset_size = dataset_size

    # download dataset.
    if download:
        self.download()
    if not self._check_exists():
        raise RuntimeError("Dataset not found."
                           " You can use download=True to download it")

    # NOTE(review): the samples are always stored as train_data/train_labels,
    # even when train=False; presumably load_samples() picks the split from
    # self.train and the rest of the class reads these names — TODO confirm.
    self.train_data, self.train_labels = self.load_samples()
    if self.train:
        # One shared permutation keeps every image aligned with its label;
        # slicing with dataset_size=None keeps all samples.
        indices = np.arange(self.train_labels.shape[0])
        np.random.shuffle(indices)
        keep = indices[:self.dataset_size]
        self.train_data = self.train_data[keep]
        self.train_labels = self.train_labels[keep]
    # Rescale to 0-255 (assumes load_samples() yields floats in [0, 1] —
    # TODO confirm), then move axis 1 to the end: N x C x H x W -> N x H x W x C.
    self.train_data *= 255.0
    self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC
评论列表
文章目录