def load_mnist(test_size=0.2):
    '''Load MNIST as snake-ordered pixel sequences and split it into train/test sets.

    Each 28x28 image is flattened row by row with every odd-indexed row
    (1, 3, 5, ...) reversed, so consecutive elements of the resulting
    sequence are always spatially adjacent pixels. Pixel values are divided
    by 255 (MNIST pixels are in 0-255).

    Args:
        test_size: fraction of samples placed in the test split, passed to
            ``train_test_split``. Defaults to 0.2, the previously hard-coded value.

    Returns:
        (train_X, test_X, train_y, test_y): the X arrays are float32 with shape
        (n_samples, 784, 1); the y arrays are int64 class labels.

    Both the shuffle and the split use the module-level ``seed``.
    '''
    # fetch_mldata was removed in scikit-learn 0.22 and the mldata.org service
    # it downloaded from is offline; OpenML hosts the same 70,000 images.
    # NOTE(review): any module-level `fetch_mldata` import should be removed
    # as well, otherwise this module fails to import on current scikit-learn.
    from sklearn.datasets import fetch_openml

    mnist = fetch_openml('mnist_784', version=1, as_frame=False)
    mnist_X, mnist_y = shuffle(mnist.data, mnist.target, random_state=seed)
    # PyTorch expects float32 inputs and int64 (long) class labels. Casting
    # before dividing avoids materialising a float64 copy of the whole dataset
    # and yields the same float32 values.
    mnist_X = mnist_X.astype('float32') / 255.0
    # OpenML returns labels as strings ('0'..'9'); astype parses them to ints.
    mnist_y = mnist_y.astype('int64')

    def flatten_img(images):
        '''Flatten images in boustrophedon ("snake") order.

        images: shape => (n, rows, columns); odd-indexed rows are reversed
            in place, so the input array is modified.
        output: shape => (n, rows*columns)
        '''
        n_rows, n_columns = images.shape[1], images.shape[2]
        # Reverse every odd row in one vectorized assignment instead of a
        # Python loop; .copy() keeps the reversed source from aliasing the
        # destination it is written into.
        images[:, 1::2, :] = images[:, 1::2, ::-1].copy()
        return images.reshape(-1, n_rows * n_columns)

    mnist_X = flatten_img(mnist_X.reshape(-1, 28, 28))  # (n_samples, seq_len=784)
    mnist_X = mnist_X[:, :, np.newaxis]  # (n_samples, seq_len, n_features=1)
    # Hold out a test split, reusing the same seed for reproducibility.
    train_X, test_X, train_y, test_y = train_test_split(
        mnist_X, mnist_y, test_size=test_size, random_state=seed)
    return train_X, test_X, train_y, test_y
评论列表
文章目录