main_STL.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:scalingscattering 作者: edouardoyallon 项目源码 文件源码
def create_dataset(opt, mode, fold=0):
    """Build the image dataset for training (with augmentation) or testing.

    Args:
        opt: options object. This function reads ``opt.dataset``, the name of
            a dataset class looked up on the ``datasets`` module and built
            with root ``'.'`` and ``download=True``. It also reads
            ``opt.randomcrop_pad``, the padding (reflect border) applied
            before the 96-pixel random crop used during training.
        mode: truthy to build the ``'train'`` split with augmentation; falsy
            to build the ``'test'`` split with tensor conversion only.
        fold: training split only. Index of the line in
            ``./stl10_binary/fold_indices.txt`` whose whitespace-separated
            integers select the training subset. A negative value uses the
            whole training split. Ignored when ``mode`` is falsy.

    Returns:
        A ``tnt.dataset.TensorDataset`` built from the (images, labels) pair,
        with the split's transform attached to element 0 (the image).

    Raises:
        IndexError: if ``fold`` is not a valid line index of the fold file.
    """
    # Cast to float32, scale pixel values by 1/255, move the channel axis
    # first (HWC -> CHW), then wrap the array as a torch tensor.
    convert = tnt.transform.compose([
        lambda x: x.astype(np.float32),
        lambda x: x / 255.0,
        # NOTE: per-channel normalization is intentionally disabled:
        # cvtransforms.Normalize([125.3, 123.0, 113.9], [63.0,  62.1,  66.7]),
        lambda x: x.transpose(2, 0, 1).astype(np.float32),
        torch.from_numpy,
    ])

    # Training augmentation: random horizontal flip, reflect-pad, then a
    # random 96-pixel crop, followed by the same tensor conversion as test.
    train_transform = tnt.transform.compose([
        cvtransforms.RandomHorizontalFlip(),
        cvtransforms.Pad(opt.randomcrop_pad, cv2.BORDER_REFLECT),
        cvtransforms.RandomCrop(96),
        convert,
    ])

    smode = 'train' if mode else 'test'
    ds = getattr(datasets, opt.dataset)('.', split=smode, download=True)

    # Reorder axes (0, 2, 3, 1), presumably (N, C, H, W) -> (N, H, W, C),
    # so the per-image cv transforms above receive channels-last arrays.
    data = ds.data.transpose(0, 2, 3, 1)
    labels = ds.labels

    if mode and fold > -1:
        # Close the fold file deterministically instead of leaking the handle.
        with open('./stl10_binary/fold_indices.txt') as fold_file:
            fold_lines = fold_file.readlines()
        # Materialize the indices as a list: in Python 3 `map` is a lazy
        # iterator and cannot be used to index a NumPy array. A bare split()
        # also tolerates trailing spaces and '\r\n' line endings.
        folds_idx = [int(tok) for tok in fold_lines[fold].split()]
        data = data[folds_idx]
        labels = labels[folds_idx]

    ds = tnt.dataset.TensorDataset([data, labels.tolist()])
    return ds.transform({0: train_transform if mode else convert})
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号