def get_svhn_full(save_dir=None, root_path=None):
    '''Load SVHN plus its "extra" set and carve out a validation split.

    Exactly one of ``save_dir`` / ``root_path`` must be given:

    * ``save_dir``  -- ``extra_32x32.mat`` is downloaded into
      ``<save_dir>/og_data`` (the directory is created if missing) and
      loaded from there.
    * ``root_path`` -- ``extra_32x32.mat`` is read from this existing
      directory; nothing is downloaded here.

    The base train/test arrays come from ``get_svhn(save_dir, root_path)``.
    For every label found in the base training labels, the first 400 base
    images and the first 200 extra images of that label go to the validation
    set; all remaining images of that label go to the training set.  If a
    class has fewer images than the requested validation count, all of them
    go to validation and the label arrays are sized to match.

    Returns Xtr, Ytr, Xval, Yval, Xte, Yte as numpy arrays.

    Raises AssertionError if both or neither of the two arguments are None.
    '''
    # Exactly one of the two locations must be supplied.
    assert (save_dir is None) != (root_path is None)
    Xtr_small, Ytr_small, Xte, Yte = get_svhn(save_dir, root_path)

    if root_path is None:
        root = os.path.join(save_dir, 'og_data')
        if not os.path.isdir(root):
            os.mkdir(root)
        extra_mat = os.path.join(root, "extra_32x32.mat")
        url = urllib.URLopener()
        # Parenthesised single-argument form prints identically on Python 2.
        print('Downloading Svhn Extra...')
        url.retrieve("http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", extra_mat)
    else:
        root = root_path

    extra = io.loadmat(os.path.join(root, 'extra_32x32.mat'))
    # Move the last axis to the front and the third axis second; for an
    # (H, W, C, N) array this gives (N, C, H, W).
    # NOTE(review): assumes the .mat stores images as (32, 32, 3, N) -- confirm.
    Xtr_extra = np.transpose(extra['X'], (3, 2, 0, 1))
    Ytr_extra = extra['y']
    # Flatten to 1-D and shift labels down by one.
    # NOTE(review): presumably this matches the label convention used by
    # get_svhn for Ytr_small -- confirm.
    Ytr_extra = Ytr_extra.reshape(Ytr_extra.shape[:1]) - 1
    print('Xextra shape %s' % (Xtr_extra.shape,))
    print('Yextra shape %s' % (Ytr_extra.shape,))

    # (images, labels, number of leading images per class held out for validation)
    sources = (
        (Xtr_small, Ytr_small, 400),
        (Xtr_extra, Ytr_extra, 200),
    )
    val_x = []
    val_y = []
    train_x = []
    train_y = []
    for label in np.unique(Ytr_small):
        for images, labels, n_val in sources:
            of_label = images[labels == label]
            held_out = of_label[:n_val]
            kept = of_label[n_val:]
            val_x.append(held_out)
            train_x.append(kept)
            # Size the label vectors from the actual slices so images and
            # labels always line up, even for classes smaller than n_val.
            # np.repeat also keeps the label dtype when a slice is empty.
            val_y.append(np.repeat(label, len(held_out)))
            train_y.append(np.repeat(label, len(kept)))

    Xtr = np.concatenate(train_x)
    Ytr = np.concatenate(train_y)
    Xval = np.concatenate(val_x)
    Yval = np.concatenate(val_y)
    return Xtr, Ytr, Xval, Yval, Xte, Yte
评论列表
文章目录