def get_mnist(nbatch=128, data_home='/home/shaofan/.sklearn/'):
    """Load MNIST, shuffle it once, and build an endless random-batch generator.

    Parameters
    ----------
    nbatch : int
        Number of images drawn per batch by the returned generator.
    data_home : str
        Cache directory passed to ``fetch_mldata``. The default is the
        original hard-coded path, kept for backward compatibility; pass a
        directory that exists on your machine instead.

    Returns
    -------
    x : ndarray
        All images reshaped to ``(-1, 1, 28, 28)`` and shuffled.
    y : ndarray
        Targets shuffled with the same permutation as ``x`` (presumably a
        1-D label vector -- confirm against ``fetch_mldata`` output).
    random_stream : callable
        Zero-argument generator function. Each item it yields is a batch of
        ``nbatch`` distinct images (sampled without replacement within the
        batch, with replacement across batches) transposed from
        ``(nbatch, 1, 28, 28)`` to ``(nbatch, 28, 28, 1)``. Labels are not
        yielded.

    Notes
    -----
    Shuffling and batch sampling use NumPy's global random state
    (``np.random``), so seed it beforehand for reproducibility.

    NOTE(review): ``sklearn.datasets.fetch_mldata`` was deprecated in
    scikit-learn 0.20 and removed in 0.22 (mldata.org is offline). On newer
    versions this call fails; ``fetch_openml('mnist_784')`` is the usual
    replacement, but it returns string labels and a different ordering.
    """
    mnist = fetch_mldata('MNIST original', data_home=data_home)
    x, y = mnist.data, mnist.target
    x = x.reshape(-1, 1, 28, 28)

    # Shuffle images and targets with one shared permutation so pairs stay aligned.
    ind = np.random.permutation(x.shape[0])
    x = x[ind]
    y = y[ind]

    n_samples = x.shape[0]  # loop-invariant; computed once

    def random_stream():
        while True:
            idx = np.random.choice(n_samples, replace=False, size=nbatch)
            # Channels-first (N, 1, H, W) -> channels-last (N, H, W, 1).
            yield x[idx].transpose(0, 2, 3, 1)

    return x, y, random_stream
评论列表
文章目录