data.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目: a3c 作者: siemanko 项目源码 文件源码
def __init__(self, batch_size, validation_size):
        """Load MNIST and split it into train, validation and test sets.

        Images are converted to float32 and scaled to ``[0.0, 1.0]``; labels
        are one-hot encoded as float32 rows of length 10. The first 60,000
        examples (the canonical MNIST training set) are shuffled, and the
        last ``validation_size`` of them are held out for validation. The
        remaining 10,000 examples (the canonical MNIST test set) are kept
        as the test split, untouched by the shuffle.

        Args:
            batch_size: Stored unchanged on ``self.batch_size``.
            validation_size: Number of training examples to hold out for
                validation; must satisfy ``0 <= validation_size <= 60000``.

        Raises:
            ValueError: If ``validation_size`` is outside ``[0, 60000]``.
        """
        validation_end = 60000  # canonical MNIST train/test boundary
        test_end = 70000        # total number of MNIST examples

        # Validate before downloading: out-of-range values would otherwise
        # produce overlapping or silently empty splits through negative or
        # oversized slice bounds.
        if not 0 <= validation_size <= validation_end:
            raise ValueError(
                "validation_size must be in [0, %d], got %r"
                % (validation_end, validation_size))

        self.batch_size = batch_size

        # Load MNIST.
        # NOTE: fetch_mldata was deprecated in scikit-learn 0.20 and removed
        # in 0.22 (mldata.org is offline). On newer versions migrate to
        # fetch_openml('mnist_784', version=1, as_frame=False), whose first
        # 60,000 rows are likewise the training set.
        mnist = fetch_mldata('MNIST original')
        X, Y_labels = mnist['data'], mnist['target']

        # Normalize X to the (0.0, 1.0) range, keeping float32 throughout.
        X = X.astype(np.float32) / np.float32(255.0)

        # One-hot encode the labels directly into a float32 matrix.
        labels = Y_labels.astype(np.int32)
        Y = np.zeros((len(labels), 10), dtype=np.float32)
        Y[np.arange(len(labels)), labels] = 1.0

        # Separate the canonical test set BEFORE shuffling, so the test split
        # stays the standard MNIST test set instead of a random mix of
        # training and test images.
        X_train_all, X_test = X[:validation_end], X[validation_end:test_end]
        Y_train_all, Y_test = Y[:validation_end], Y[validation_end:test_end]

        # Shuffle only the training portion, so the validation slice taken
        # below is a random sample rather than a contiguous run of the
        # stored order.
        permutation = np.random.permutation(len(X_train_all))
        X_train_all = X_train_all[permutation]
        Y_train_all = Y_train_all[permutation]

        # Hold out the last `validation_size` shuffled training examples.
        train_end = len(X_train_all) - validation_size

        self.X_train = X_train_all[:train_end]
        self.X_valid = X_train_all[train_end:]
        self.X_test  = X_test

        self.Y_train = Y_train_all[:train_end]
        self.Y_valid = Y_train_all[train_end:]
        self.Y_test  = Y_test
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号