tf-keras-skeleton.py source file

python
Project: LIE  Author: EmbraceLife
import numpy as np


def get_test_data(train_samples,
                  test_samples,
                  input_shape,
                  num_classes):
    """Generates synthetic classification data for testing a model.

    Each class gets a random template array; a sample is its class
    template plus unit-variance Gaussian noise.

    Arguments:
      train_samples: Integer, how many training samples to generate.
      test_samples: Integer, how many test samples to generate.
      input_shape: Tuple of integers, shape of the inputs.
      num_classes: Integer, number of classes for the data and targets.

    Returns:
      A tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    num_sample = train_samples + test_samples
    # One random template per class, with entries in [0, 2 * num_classes).
    templates = 2 * num_classes * np.random.random((num_classes,) + input_shape)
    # Integer class label for every sample.
    y = np.random.randint(0, num_classes, size=(num_sample,))
    x = np.zeros((num_sample,) + input_shape)
    for i in range(num_sample):
        # Sample = class template + standard normal noise.
        x[i] = templates[y[i]] + np.random.normal(loc=0, scale=1., size=input_shape)
    # The first train_samples rows form the training split, the rest the test split.
    return ((x[:train_samples], y[:train_samples]),
            (x[train_samples:], y[train_samples:]))
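
A short usage sketch may help show what the function returns. The block repeats a condensed copy of `get_test_data` so it runs on its own. That copy replaces the per-sample loop with NumPy fancy indexing, which is a swapped-in variation and not the original code. The seed and the argument values (100 training samples, 20 test samples, inputs of shape `(4, 3)`, 5 classes) are arbitrary assumptions chosen for illustration.

```python
import numpy as np


def get_test_data(train_samples, test_samples, input_shape, num_classes):
    # Condensed copy of the function above, repeated so this sketch runs alone.
    num_sample = train_samples + test_samples
    templates = 2 * num_classes * np.random.random((num_classes,) + input_shape)
    y = np.random.randint(0, num_classes, size=(num_sample,))
    # Fancy indexing picks each sample's class template; noise is added on top.
    noise = np.random.normal(loc=0, scale=1., size=(num_sample,) + input_shape)
    x = templates[y] + noise
    return ((x[:train_samples], y[:train_samples]),
            (x[train_samples:], y[train_samples:]))


np.random.seed(0)  # arbitrary seed, only to make the run repeatable
(x_train, y_train), (x_test, y_test) = get_test_data(
    train_samples=100, test_samples=20, input_shape=(4, 3), num_classes=5)

print(x_train.shape, y_train.shape)  # (100, 4, 3) (100,)
print(x_test.shape, y_test.shape)    # (20, 4, 3) (20,)
```

The labels are integers in `[0, num_classes)`, not one-hot vectors, so a model trained on this data would need a sparse categorical loss or an explicit one-hot conversion.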