def _read_idx_payload(path, header_size, shape):
    """Read a file as raw uint8, drop its leading header bytes and reshape.

    Args:
        path: Path of the file to read.
        header_size: Number of leading bytes to discard before the payload.
        shape: Target shape for the remaining bytes.

    Returns:
        A ``uint8`` array of the requested shape.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the payload size does not match ``shape``.
    """
    # Handing NumPy the path (rather than an already-open file object) makes
    # it open the file in binary mode and close it again, so no handle leaks
    # and no text-mode newline translation can touch the bytes.
    raw = np.fromfile(path, dtype=np.uint8)
    return raw[header_size:].reshape(shape)


def mnist():
    """Load the MNIST train/test images and labels from ``data_dir``.

    Reads the four ``*-ubyte`` files found in the module-level ``data_dir``.
    The first 16 bytes of each image file and the first 8 bytes of each label
    file are skipped (presumably the IDX file header) and the rest is used as
    the payload.

    Returns:
        A tuple ``(trX, teX, trY, teY)`` where

        * ``trX`` is a float array of shape ``(60000, 784)``,
        * ``teX`` is a float array of shape ``(10000, 784)``,
        * ``trY`` is a ``uint8`` array of shape ``(60000,)``,
        * ``teY`` is a ``uint8`` array of shape ``(10000,)``.

        Pixel values are converted to float but not rescaled.

    Raises:
        OSError: If any of the four files cannot be opened.
        ValueError: If a file does not hold the expected number of bytes.
    """
    image_header = 16  # bytes skipped at the start of each image file
    label_header = 8   # bytes skipped at the start of each label file
    pixels = 28 * 28
    n_train = 60000
    n_test = 10000

    trX = _read_idx_payload(
        os.path.join(data_dir, 'train-images-idx3-ubyte'),
        image_header,
        (n_train, pixels),
    ).astype(float)
    trY = _read_idx_payload(
        os.path.join(data_dir, 'train-labels-idx1-ubyte'),
        label_header,
        n_train,
    )
    teX = _read_idx_payload(
        os.path.join(data_dir, 't10k-images-idx3-ubyte'),
        image_header,
        (n_test, pixels),
    ).astype(float)
    teY = _read_idx_payload(
        os.path.join(data_dir, 't10k-labels-idx1-ubyte'),
        label_header,
        n_test,
    )
    return trX, teX, trY, teY
评论列表
文章目录