nn_keras_digits.py source code

python

Project: python_utils    Author: Jayhello
from keras.datasets.mnist import load_data
from keras.utils import to_categorical


def generate_data():
    """Load MNIST; return flattened, normalized images and one-hot labels."""
    (X_train, y_train), (X_test, y_test) = load_data()

    # X_train.shape -> (60000, 28, 28), i.e. 60000 images of 28*28 pixels
    print(X_train.shape[1], X_train.shape[2], X_train.shape)

    # flatten each 28*28 image into a 784-element vector
    num_pixels = X_train.shape[1] * X_train.shape[2]
    X_train = X_train.reshape(X_train.shape[0], num_pixels).astype('float32')
    X_test = X_test.reshape(X_test.shape[0], num_pixels).astype('float32')

    # normalize inputs from 0-255 to 0-1
    X_train = X_train / 255
    X_test = X_test / 255

    # one-hot encode the digit labels 0-9
    y_train = to_categorical(y_train)
    y_test = to_categorical(y_test)
    # y_train.shape -> (60000, 10), y_test.shape -> (10000, 10)
    num_classes = y_test.shape[1]  # 10 classes; computed but not returned

    return X_train, y_train, X_test, y_test
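To see what each preprocessing step does without downloading MNIST, here is a minimal NumPy-only sketch. It assumes inputs with the same layout that `load_data()` returns (uint8 images of shape (N, 28, 28) and integer labels 0 to 9). The helpers `one_hot` and `preprocess` and the random stand-in data are hypothetical. `one_hot` is written to mimic the default behaviour of `to_categorical`, which infers the number of classes as the largest label plus one.

```python
import numpy as np


def one_hot(labels, num_classes=None):
    """Hypothetical stand-in for keras.utils.to_categorical.

    If num_classes is None, it is inferred as max(label) + 1.
    """
    labels = np.asarray(labels, dtype=int).ravel()
    if num_classes is None:
        num_classes = int(labels.max()) + 1
    out = np.zeros((labels.shape[0], num_classes), dtype='float32')
    # set a single 1.0 per row, in the column given by that row's label
    out[np.arange(labels.shape[0]), labels] = 1.0
    return out


def preprocess(images, labels):
    """Flatten (N, H, W) uint8 images to (N, H*W) floats in [0, 1]; one-hot the labels."""
    num_pixels = images.shape[1] * images.shape[2]
    x = images.reshape(images.shape[0], num_pixels).astype('float32') / 255
    y = one_hot(labels)
    return x, y


# Synthetic stand-in for MNIST: 5 random 28x28 "images" with digit labels
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)
labels = np.array([3, 0, 9, 1, 3])

x, y = preprocess(images, labels)
print(x.shape, y.shape)  # (5, 784) (5, 10)
```

Because the number of classes is inferred from the data, a batch that happens to lack label 9 would yield fewer than 10 columns; passing `num_classes=10` explicitly (to `one_hot` here, or to `to_categorical`) avoids that.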