autoencoder.py source code

python

Project: dsde-deep-learning  Author: broadinstitute
import numpy as np
from keras.datasets import mnist  # assumed import; the original snippet omits it


def load_mnist(flatten=True):
    """Load MNIST, scale pixels to [0, 1], and reshape for dense or conv models."""
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # Convert uint8 pixel values to float32 in the range [0, 1].
    x_train = x_train.astype('float32') / 255.
    x_test = x_test.astype('float32') / 255.

    if flatten:
        # Flatten each 28x28 image into a single 784-element vector.
        x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
        x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
    else:
        # Add a trailing channel axis (channels_last layout);
        # adapt this if using the `channels_first` image data format.
        x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))
        x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))

    print(x_train.shape)
    print(x_test.shape)

    return (x_train, y_train), (x_test, y_test)
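
The two branches of `load_mnist` differ only in the output shape. A minimal sketch of that difference, using a small synthetic `uint8` array in place of the real MNIST download; the helper name `preprocess` and the batch of 4 random images are assumptions for illustration and are not part of the original file:

```python
import numpy as np


def preprocess(images, flatten=True):
    """Hypothetical helper mirroring the scaling and reshaping in load_mnist."""
    # Scale uint8 pixel values from [0, 255] to float32 in [0, 1].
    images = images.astype('float32') / 255.
    if flatten:
        # One row per image: (n, 28, 28) -> (n, 784).
        return images.reshape((len(images), -1))
    # Add a trailing channel axis: (n, 28, 28) -> (n, 28, 28, 1).
    return images.reshape(images.shape + (1,))


# Synthetic stand-in for MNIST: 4 random 28x28 grayscale images.
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

flat = preprocess(fake_images, flatten=True)
conv = preprocess(fake_images, flatten=False)
print(flat.shape)  # (4, 784)
print(conv.shape)  # (4, 28, 28, 1)
```

The flattened form is the usual input for a fully connected autoencoder, while the 4-D form with a channel axis is what convolutional layers typically expect.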