def get_data():
    """Load the 'MNIST original' dataset and return it ready for training.

    The dataset is requested through ``fetch_mldata`` with its data home set
    to a ``data`` directory next to this file. Progress messages are printed
    before and after the fetch.

    Returns:
        A ``(n_classes, X_train, y_train)`` tuple, where ``n_classes`` is the
        number of distinct values in the targets, ``X_train`` is the dataset's
        ``data`` divided by 255.0, and ``y_train`` is the dataset's ``target``.
    """
    print("loading data, please wait ...")
    # NOTE(review): fetch_mldata is presumably sklearn.datasets.fetch_mldata,
    # which was removed in scikit-learn 0.22 -- confirm the pinned version or
    # plan a migration to fetch_openml.
    cache_dir = os.path.join(os.path.dirname(__file__), './data')
    dataset = fetch_mldata('MNIST original', data_home=cache_dir)
    print('data loading is done ...')

    # Presumably the data holds 8-bit pixel intensities, so this maps them
    # into [0, 1] -- TODO confirm.
    features = dataset.data / 255.0
    labels = dataset.target
    return np.unique(labels).size, features, labels
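# (Appears to be web-page residue: "Comment list")
# (Appears to be web-page residue: "Article table of contents")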