def mnist(digit='all', n_samples=0, return_gt=False):
    """Load MNIST images, optionally filtered to one digit and subsampled.

    Parameters
    ----------
    digit : 'all' or number, optional
        ``'all'`` keeps every class; any other value keeps only the rows
        whose label compares equal to it.
    n_samples : int, optional
        When greater than 0, that many rows are drawn uniformly *with
        replacement* (duplicates are possible). Values <= 0 keep every row.
    return_gt : bool, optional
        When true, the label vector is returned alongside the data.

    Returns
    -------
    X or (X, gt)
        Data rows ordered by ascending label, plus the matching labels
        when ``return_gt`` is true.

    Raises
    ------
    ValueError
        If ``n_samples`` exceeds the number of rows left after filtering.

    Notes
    -----
    Subsampling calls ``np.random.seed(0)``, i.e. it reseeds NumPy's
    *global* generator. This makes the selection reproducible but also
    affects any later use of ``np.random`` by the caller.
    """
    # `fetch_mldata` is absent from recent scikit-learn releases; use it
    # when available (unchanged legacy behaviour), otherwise fall back to
    # the OpenML copy of the dataset.
    legacy_loader = getattr(sk_datasets, 'fetch_mldata', None)
    if legacy_loader is not None:
        dataset = legacy_loader('MNIST original')
        X = dataset.data
        gt = dataset.target
    else:
        dataset = sk_datasets.fetch_openml('mnist_784', version=1,
                                           as_frame=False)
        X = dataset.data
        # OpenML delivers labels as strings; cast so that the numeric
        # `gt == digit` comparison below keeps working.
        gt = dataset.target.astype(float)

    if digit != 'all':
        # Build the mask once, before `gt` itself is overwritten.
        keep = gt == digit
        X = X[keep, :]
        gt = gt[keep]

    if n_samples > len(X):
        raise ValueError('Requesting {} samples '
                         'from {} datapoints'.format(n_samples, len(X)))

    if n_samples > 0:
        # NOTE: reseeds the global NumPy RNG (kept for backward
        # compatibility with callers relying on the fixed seed).
        np.random.seed(0)
        selection = np.random.randint(len(X), size=n_samples)
        X = X[selection, :]
        gt = gt[selection]

    # Group the rows by class: order everything by ascending label.
    idx = np.argsort(gt)
    X = X[idx, :]
    gt = gt[idx]

    if return_gt:
        return X, gt
    return X
# [web page residue] Comment list
# [web page residue] Table of contents