def get_data(data_name='mnist', test_flag=False):
    """Load one image split, rescale it, and insert a length-1 second axis.

    Args:
        data_name: ``'daudi'`` loads via ``daudi_load_data()``; any other
            value falls through to ``mnist.load_data()``.
        test_flag: if true, return the test images instead of the
            training images.

    Returns:
        The selected image array after rescaling, reshaped to
        ``(N, 1) + original_shape[1:]`` (presumably a channels-first
        channel axis — confirm against the consuming model). Labels are
        discarded.
    """
    if data_name == 'daudi':
        (train_images, _), (test_images, _) = daudi_load_data()
        images = test_images if test_flag else train_images
        # Raw values are noted as roughly 1 +/- 0.2; shifting by 1 and
        # scaling by 5 maps that range to about [-1, 1].
        images = (images - 1.0) * 5.0
    else:
        (train_images, _), (test_images, _) = mnist.load_data()
        images = test_images if test_flag else train_images
        # Map 0..255 pixel values to [-1, 1] in float32.
        images = (images.astype(np.float32) - 127.5) / 127.5

    # Insert a length-1 axis right after the sample axis.
    new_shape = (images.shape[0], 1) + images.shape[1:]
    return images.reshape(new_shape)
评论列表
文章目录