def get_data(problem, n_train, n_batch):
    """Load the train/validation data for the named problem.

    Args:
        problem: Dataset name; one of 'cifar10', 'svhn', 'mnist', 'lfw'.
        n_train: If > 0, both the training and validation sets are cut via
            ``G.ndict.getRows(split, 0, n_train)`` (presumably keeping the
            first ``n_train`` rows -- confirm against G.ndict.getRows).
        n_batch: Number of leading training examples placed in ``data_init``.

    Returns:
        A tuple ``(data_train, data_valid, data_init)``. ``data_init`` is
        ``{'x': data_train['x'][:n_batch]}``, taken *before* the optional
        ``n_train`` truncation.

    Raises:
        ValueError: If ``problem`` is not one of the supported names.
    """
    if problem == 'cifar10':
        data_train, data_valid = G.misc.data.cifar10(False)
    elif problem == 'svhn':
        data_train, data_valid = G.misc.data.svhn(False, True)
    elif problem == 'mnist':
        # The first argument is the "validset" flag. It is always False here;
        # with True the loader yields a third (test) split, which this
        # function has no use for.
        data_train, data_valid = G.misc.data.mnist_binarized(False, False)
        # Reshape into (N, 1, 28, 28) image tensors.
        data_train['x'] = data_train['x'].reshape((-1, 1, 28, 28))
        data_valid['x'] = data_valid['x'].reshape((-1, 1, 28, 28))
    elif problem == 'lfw':
        data_train = G.misc.data.lfw(False, True)
        # Validation rows are taken from the training data itself
        # (getRows 0..1000), so the two sets overlap.
        data_valid = G.ndict.getRows(data_train, 0, 1000)
    else:
        # Without this, an unknown name would surface later as a confusing
        # UnboundLocalError on data_train.
        raise ValueError(
            "Unknown problem %r; expected one of 'cifar10', 'svhn', "
            "'mnist', 'lfw'" % (problem,)
        )

    data_init = {'x': data_train['x'][:n_batch]}
    if n_train > 0:
        data_train = G.ndict.getRows(data_train, 0, n_train)
        data_valid = G.ndict.getRows(data_valid, 0, n_train)
    return data_train, data_valid, data_init