mnist_test.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:cyclegan_keras 作者: shadySource 项目源码 文件源码
def load_mnist():
    '''
    returns mnist_data
    '''
    # input image dimensions
    img_rows, img_cols = 28, 28

    # the data, shuffled and split between train and test sets
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    if k.image_data_format() == 'channels_first':
        x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
        input_shape = (1, img_rows, img_cols)
    else:
        x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
        input_shape = (img_rows, img_cols, 1)

    x_train = x_train.astype(k.floatx())
    x_train *= 0.96/255
    x_train += 0.02
    return input_shape, x_train
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号