cifar.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:attention-transfer 作者: szagoruyko 项目源码 文件源码
def create_dataset(opt, mode):
    convert = tnt.transform.compose([
        lambda x: x.astype(np.float32),
        T.Normalize([125.3, 123.0, 113.9], [63.0, 62.1, 66.7]),
        lambda x: x.transpose(2,0,1).astype(np.float32),
        torch.from_numpy,
    ])

    train_transform = tnt.transform.compose([
        T.RandomHorizontalFlip(),
        T.Pad(opt.randomcrop_pad, cv2.BORDER_REFLECT),
        T.RandomCrop(32),
        convert,
    ])

    ds = getattr(datasets, opt.dataset)(opt.data_root, train=mode, download=True)
    smode = 'train' if mode else 'test'
    ds = tnt.dataset.TensorDataset([
        getattr(ds, smode+'_data'),
        getattr(ds, smode+'_labels')])
    return ds.transform({0: train_transform if mode else convert})
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号