reversing_gan.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:gandlf 作者: codekansas 项目源码 文件源码
def get_mnist_data(binarize=False):
    """Puts the MNIST data in the right format."""

    (X_train, y_train), (X_test, y_test) = mnist.load_data()

    if binarize:
        X_test = np.where(X_test >= 10, 1, -1)
        X_train = np.where(X_train >= 10, 1, -1)
    else:
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_test = (X_test.astype(np.float32) - 127.5) / 127.5

    X_train = np.expand_dims(X_train, axis=-1)
    X_test = np.expand_dims(X_test, axis=-1)

    y_train = np.eye(10)[y_train]
    y_test = np.eye(10)[y_test]

    return (X_train, y_train), (X_test, y_test)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号