def get_test_data(train_samples,
                  test_samples,
                  input_shape,
                  num_classes):
    """Generates synthetic labelled data to train and test a model on.

    One random "template" array of shape `input_shape` is drawn per class
    (values in `[0, 2 * num_classes)`). Every sample is the template of its
    randomly chosen class plus standard-normal noise.

    All random draws use the global `np.random` state, so call
    `np.random.seed(...)` beforehand for reproducible output.

    Arguments:
        train_samples: Integer, how many training samples to generate.
        test_samples: Integer, how many test samples to generate.
        input_shape: Tuple (or other sequence) of integers, shape of a
            single input sample.
        num_classes: Integer, number of classes; labels are drawn
            uniformly from `0 .. num_classes - 1`.

    Returns:
        A tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
        `x_*` have shape `(n,) + input_shape`; `y_*` are integer label
        arrays of shape `(n,)`.
    """
    # Accept lists as well as tuples; tuple concatenation below needs a tuple.
    input_shape = tuple(input_shape)
    num_sample = train_samples + test_samples

    templates = 2 * num_classes * np.random.random((num_classes,) + input_shape)
    y = np.random.randint(0, num_classes, size=(num_sample,))

    # Draw the noise for all samples at once and select each sample's class
    # template with integer-array indexing instead of looping per sample.
    # The draw order (templates, labels, noise) is kept unchanged.
    noise = np.random.normal(loc=0, scale=1., size=(num_sample,) + input_shape)
    x = templates[y] + noise

    return ((x[:train_samples], y[:train_samples]),
            (x[train_samples:], y[train_samples:]))
# Comment list
# Article table of contents