get_datasets.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:metaqnn 作者: bowenbaker 项目源码 文件源码
def get_svhn(save_dir=None, root_path=None):
    """Load the SVHN train/test ``.mat`` files, downloading them first if asked.

    Exactly one of ``save_dir`` and ``root_path`` must be given:

    * ``save_dir`` -- download ``train_32x32.mat`` and ``test_32x32.mat``
      into ``<save_dir>/og_data`` (created if missing) and load them from
      there.
    * ``root_path`` -- skip the download and load the two files from this
      existing directory.

    Returns:
        ``(Xtr, Ytr, Xte, Yte)`` as numpy arrays. Each ``X`` is the stored
        ``'X'`` array with its axes permuted by ``(3, 2, 0, 1)``; each ``Y``
        is the stored ``'y'`` array flattened to its first dimension and
        shifted down by one.
        NOTE(review): presumably the files hold images as (H, W, C, N) and
        labels in 1..10, which would yield (N, C, H, W) images and labels in
        0..9 -- confirm against the dataset description.

    Raises:
        AssertionError: if both or neither of the arguments are given.
    """
    # Raised explicitly (instead of via ``assert``) so the check survives
    # ``python -O``; the exception type is kept for backward compatibility.
    if (save_dir is None) == (root_path is None):
        raise AssertionError(
            'Exactly one of save_dir and root_path must be provided')

    if root_path is None:
        root = os.path.join(save_dir, 'og_data')
        if not os.path.isdir(root):
            # makedirs (not mkdir) so a not-yet-existing save_dir is created too.
            os.makedirs(root)

        try:
            # Python 3 location of the download helper.
            from urllib.request import urlretrieve
        except ImportError:
            # Python 2: fall back to the opener the original code used.
            import urllib
            urlretrieve = urllib.URLopener().retrieve

        print('Downloading Svhn Train...')
        urlretrieve("http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
                    os.path.join(root, 'train_32x32.mat'))
        print('Downloading Svhn Test...')
        urlretrieve("http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
                    os.path.join(root, 'test_32x32.mat'))
    else:
        # Use the same ``is None`` test as above so a falsy-but-valid path
        # (e.g. '' for the current directory) does not fall through to an
        # undefined download directory.
        root = root_path

    train = io.loadmat(os.path.join(root, 'train_32x32.mat'))
    Xtr = train['X']
    Ytr = train['y']
    del train  # drop the loaded dict before reading the second file

    test = io.loadmat(os.path.join(root, 'test_32x32.mat'))
    Xte = test['X']
    Yte = test['y']
    del test

    # Move the last axis (sample index) first and the third axis second.
    Xtr = np.transpose(Xtr, (3, 2, 0, 1))
    Xte = np.transpose(Xte, (3, 2, 0, 1))
    # Flatten the label column to 1-D and shift every label down by one.
    Ytr = Ytr.reshape(Ytr.shape[:1]) - 1
    Yte = Yte.reshape(Yte.shape[:1]) - 1

    # Single-argument print with %-formatting gives identical output on
    # Python 2 and Python 3.
    print('Xtrain shape %s' % (Xtr.shape,))
    print('Ytrain shape %s' % (Ytr.shape,))
    print('Xtest shape %s' % (Xte.shape,))
    print('Ytest shape %s' % (Yte.shape,))

    return Xtr, Ytr, Xte, Yte
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号