main_small_sample_class_normalized.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:scalingscattering 作者: edouardoyallon 项目源码 文件源码
def create_dataset(opt, mode):
    """Build a torchnet ``TensorDataset`` for ``opt.dataset`` with transforms attached.

    The dataset class is looked up as ``getattr(datasets, opt.dataset)`` and
    instantiated with root ``'.'``, ``train=mode`` and ``download=True``.

    In training mode (``mode`` truthy), only a class-balanced subset of
    ``opt.sampleSize`` images is kept. The same ``opt.sampleSize // num_classes``
    within-class positions are taken from every class. They are drawn with
    ``numpy.random.RandomState(opt.seed)``, so the subset is reproducible.

    The training transform applies, in order:

    - a random horizontal flip;
    - reflect padding of ``opt.randomcrop_pad``;
    - a random 32x32 crop;
    - the tensor conversion.

    In test mode the full split is used with only the tensor conversion.

    Args:
        opt: options object. Reads ``dataset`` and ``randomcrop_pad``, plus
            ``sampleSize`` and ``seed`` in training mode.
        mode: truthy for the training split, falsy for the test split.

    Returns:
        A ``tnt.dataset.TensorDataset``. Its element 0 (the image) is
        transformed as described above.

    Raises:
        ValueError: in training mode, if ``opt.sampleSize`` is not a multiple
            of the number of classes found in the labels.
    """
    # Image -> float32 torch tensor. Values are scaled to [0, 1], then the
    # last axis is moved first (presumably HWC -> CHW; TODO confirm the
    # cvtransforms output layout).
    convert = tnt.transform.compose([
        lambda x: x.astype(np.float32),
        lambda x: x / 255.0,
        # Per-channel mean/std normalization is left disabled; inputs are
        # only scaled to [0, 1].
        # cvtransforms.Normalize([125.3, 123.0, 113.9], [63.0,  62.1,  66.7]),
        lambda x: x.transpose(2, 0, 1).astype(np.float32),
        torch.from_numpy,
    ])

    train_transform = tnt.transform.compose([
        cvtransforms.RandomHorizontalFlip(),
        cvtransforms.Pad(opt.randomcrop_pad, cv2.BORDER_REFLECT),
        cvtransforms.RandomCrop(32),
        convert,
    ])

    ds = getattr(datasets, opt.dataset)('.', train=mode, download=True)
    smode = 'train' if mode else 'test'
    if mode:
        labels = np.array(ds.train_labels)
        data = ds.train_data

        # np.unique returns the classes in sorted order, as the original did.
        classes, class_counts = np.unique(labels, return_counts=True)
        num_classes = len(classes)
        if opt.sampleSize % num_classes != 0:
            raise ValueError(
                'sampleSize (%d) must be a multiple of the number of classes (%d)'
                % (opt.sampleSize, num_classes))
        # Integer division: '/' yields a float in Python 3, and a float
        # cannot be used as a slice bound.
        per_class = opt.sampleSize // num_classes

        # One set of within-class positions is drawn from range(pool_size)
        # and reused for every class. The pool therefore must not exceed the
        # smallest class, or indexing below goes out of range. 5000 is the
        # original pool size (presumably the CIFAR-10 per-class training
        # count).
        pool_size = min(5000, int(class_counts.min()))
        prng = np.random.RandomState(opt.seed)
        chosen = prng.permutation(np.arange(0, pool_size))[:per_class]

        # Classes are concatenated last-to-first. This reproduces the
        # ordering of the original prepend-in-a-loop construction.
        inds_all = np.concatenate(
            [np.flatnonzero(labels == cl)[chosen] for cl in classes[::-1]])

        # Assumes the stored images are (N, C, H, W) and the transforms
        # expect (N, H, W, C) -- TODO confirm against the dataset version.
        ds = tnt.dataset.TensorDataset([
            data[inds_all, :].transpose(0, 2, 3, 1),
            labels[inds_all].tolist()])
    else:
        ds = tnt.dataset.TensorDataset([
            getattr(ds, smode + '_data').transpose(0, 2, 3, 1),
            getattr(ds, smode + '_labels')])
    return ds.transform({0: train_transform if mode else convert})
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号