def load_cifar10_batch(fpath, one_hot=True, as_float=True):
    """Load one pickled CIFAR-10 batch file as channels-last images and labels.

    Each row of the batch's ``'data'`` entry holds one image as three
    consecutive 1024-value planes (channel 0, 1, 2), each plane laid out
    row-major as 32x32. The rows are rearranged so that
    ``X[n, r, c, ch] == data['data'][n, ch * 1024 + r * 32 + c]``.

    Args:
        fpath: Path to a CIFAR-10 batch file in Python pickle format.
        one_hot: If True, pass the label array through ``to_one_hot``
            (defined elsewhere in this module).
        as_float: If True, convert the images with skimage's
            ``img_as_float`` so they are float in [0, 1], consistent with
            skimage transforms.

    Returns:
        Tuple ``(X, Y)``: ``X`` has shape ``(N, 32, 32, 3)``; ``Y`` is the
        label array built from ``data['labels']`` (one-hot when requested).

    Raises:
        ValueError: If the rows of ``data['data']`` are not 3 * 32 * 32
            values long. Without this check, such input would be silently
            regrouped into the wrong images.

    Warning:
        ``pickle`` can execute arbitrary code while loading; only open batch
        files from a trusted source.
    """
    # Python-3 pickle: the ``encoding=`` keyword is Python-3-only, so a
    # module-level ``cPickle`` name could only have been an alias for this.
    import pickle

    channels, height, width = 3, 32, 32

    with open(fpath, 'rb') as f:
        # Batch files pickled by Python 2 need encoding='latin1' to load under
        # Python 3; latin1 maps every byte 1:1, so no data is altered.
        # https://stackoverflow.com/questions/11305790
        data = pickle.load(f, encoding='latin1')

    # Own a copy rather than aliasing the buffer pickle allocated.
    raw = np.copy(data['data'])
    row_len = channels * height * width
    if raw.ndim == 0 or raw.shape[-1] != row_len:
        raise ValueError(
            "expected rows of {} values in data['data'], got shape {}".format(
                row_len, raw.shape))
    # (N, 3072) -> (N, C, H, W) channel planes -> channels-last (N, H, W, C).
    X = raw.reshape(-1, channels, height, width).transpose(0, 2, 3, 1)

    Y = np.array(data['labels'])
    if one_hot:
        Y = to_one_hot(Y)

    # Convert to float in [0, 1] here to be consistent with skimage transforms.
    # See: http://scikit-image.org/docs/dev/user_guide/data_types.html
    if as_float:
        X = img_as_float(X)
    return X, Y
# Scraped web-page residue, not code: "评论列表" (comment list) and
# "文章目录" (table of contents).