def load(batch_size, test_batch_size, n_labelled=None):
    """Load the pickled MNIST splits and wrap each in ``mnist_generator``.

    The archive is cached at ``/tmp/mnist.pkl.gz``; when that file does not
    exist it is downloaded first.

    Args:
        batch_size: Forwarded to ``mnist_generator`` for the training split.
        test_batch_size: Forwarded to ``mnist_generator`` for the dev and
            test splits.
        n_labelled: Forwarded unchanged to ``mnist_generator`` for all three
            splits.

    Returns:
        A 3-tuple ``(train, dev, test)`` holding the ``mnist_generator``
        result for each split, in that order.
    """
    # Function-scope imports keep this block runnable on both Python 2 and
    # Python 3 without depending on how the module imported urllib.
    import sys
    try:
        from urllib.request import urlretrieve  # Python 3
    except ImportError:
        from urllib import urlretrieve  # Python 2

    filepath = '/tmp/mnist.pkl.gz'
    url = 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz'

    if not os.path.isfile(filepath):
        print("Couldn't find MNIST dataset in /tmp, downloading...")
        # Download to a side file and rename afterwards so an interrupted
        # transfer never leaves a truncated archive at the cache path (which
        # the isfile() check above would otherwise accept on the next call).
        partial_path = filepath + '.part'
        urlretrieve(url, partial_path)
        os.rename(partial_path, filepath)

    # NOTE(security): the archive is fetched over plain HTTP and unpickled;
    # unpickling can execute arbitrary code, so only use this with a source
    # you trust.
    with gzip.open(filepath, 'rb') as f:
        if sys.version_info[0] >= 3:
            # The archive appears to be a Python 2 pickle; Python 3 needs
            # encoding='latin1' to decode Python 2 byte strings / NumPy data.
            splits = pickle.load(f, encoding='latin1')
        else:
            splits = pickle.load(f)
    train_data, dev_data, test_data = splits

    return (
        mnist_generator(train_data, batch_size, n_labelled),
        mnist_generator(dev_data, test_batch_size, n_labelled),
        mnist_generator(test_data, test_batch_size, n_labelled),
    )
评论列表
文章目录