utils.py source code (Python)

Project: deepnet, author: parasdahal

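import gzip
import pickle


def load_mnist(path, num_training=50000, num_test=10000, cnn=True, one_hot=False):
    # The gzip-compressed pickle holds three (X, y) tuples: training,
    # validation and test. encoding='iso-8859-1' (Latin-1) lets Python 3
    # decode a pickle written by Python 2, such as the commonly used mnist.pkl.gz.
    with gzip.open(path, 'rb') as f:
        training_data, validation_data, test_data = pickle.load(
            f, encoding='iso-8859-1')
    X_train, y_train = training_data
    X_validation, y_validation = validation_data
    X_test, y_test = test_data
    if cnn:
        # Reshape flat 784-pixel rows into (N, channels, height, width) images.
        shape = (-1, 1, 28, 28)
        X_train = X_train.reshape(shape)
        X_validation = X_validation.reshape(shape)
        X_test = X_test.reshape(shape)
    if one_hot:
        # one_hot_encode is not shown here; it is assumed to be defined
        # elsewhere in utils.py.
        y_train = one_hot_encode(y_train, 10)
        y_validation = one_hot_encode(y_validation, 10)
        y_test = one_hot_encode(y_test, 10)
    # Keep only the first num_training / num_test examples.
    # The validation split is loaded but not returned.
    X_train, y_train = X_train[:num_training], y_train[:num_training]
    X_test, y_test = X_test[:num_test], y_test[:num_test]
    return (X_train, y_train), (X_test, y_test)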
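load_mnist calls a one_hot_encode(y, 10) helper that this snippet does not include. The version below is a hypothetical sketch, not the deepnet implementation. It assumes the labels are a 1-D array of integers in [0, num_classes) and returns a float32 matrix with one 1.0 per row.

```python
import numpy as np


def one_hot_encode(y, num_classes):
    """Hypothetical sketch of the one_hot_encode helper called by load_mnist.

    Assumes y is a 1-D array of integer class labels in [0, num_classes).
    Returns a float32 array of shape (len(y), num_classes) whose row i has
    a single 1.0 in column y[i] and zeros elsewhere.
    """
    y = np.asarray(y, dtype=np.int64)
    encoded = np.zeros((y.shape[0], num_classes), dtype=np.float32)
    # Integer-array indexing: for each row index i, set column y[i] to 1.
    encoded[np.arange(y.shape[0]), y] = 1.0
    return encoded


labels = np.array([3, 0, 9])
onehot = one_hot_encode(labels, 10)
print(onehot.shape)           # (3, 10)
print(onehot.argmax(axis=1))  # [3 0 9]
```

Taking argmax(axis=1) inverts the encoding, which is also how network outputs are usually turned back into class labels. With a helper like this in scope, load_mnist(path, one_hot=True) returns label arrays of shape (N, 10) instead of (N,).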