def _load_data(self, nb_obs=None):
    """Load the Keras dataset specified by ``self.name``.

    Image values are divided by 255 and labels are one-hot encoded with
    ``to_categorical``. For ``'mnist'`` a trailing axis of size 1 is appended
    to the image arrays.

    :param nb_obs: optional; int for the number of observations to retain
        from the training & testing sets (the same count is applied to both);
        if None, retain the full training and testing sets. Note that 0 is
        honored as a count (empty arrays), not treated as "keep everything".
    :return: a tuple of 4 np.ndarrays (x_train, y_train, x_test, y_test)
    :raises AttributeError: if ``keras.datasets`` has no attribute named
        ``self.name``
    """
    dataset = getattr(keras.datasets, self.name)
    train_data, test_data = dataset.load_data()
    x_train, y_train = train_data[0], train_data[1]
    x_test, y_test = test_data[0], test_data[1]

    # Derive the one-hot width from the FULL label sets so that (a) train and
    # test encodings always have the same width and (b) truncating with
    # nb_obs below cannot shrink the width when a subset misses the top class.
    num_classes = int(max(np.max(y_train), np.max(y_test))) + 1

    # Truncate before scaling/encoding so the float conversion and one-hot
    # expansion only happen for the rows actually returned.
    if nb_obs is not None:
        x_train, y_train = x_train[:nb_obs], y_train[:nb_obs]
        x_test, y_test = x_test[:nb_obs], y_test[:nb_obs]

    x_train = x_train / 255.
    x_test = x_test / 255.
    y_train = to_categorical(y_train, num_classes=num_classes)
    y_test = to_categorical(y_test, num_classes=num_classes)

    if self.name == 'mnist':
        # Append a trailing size-1 axis to the image arrays (presumably a
        # channel axis for channel-last models -- TODO confirm with callers).
        x_train = np.expand_dims(x_train, axis=-1)
        x_test = np.expand_dims(x_test, axis=-1)

    return x_train, y_train, x_test, y_test
评论列表
文章目录