rocket_bottom.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:Rocket-Launching 作者: zhougr1993 项目源码 文件源码
def create_dataset(opt, mode):
    convert = tnt.transform.compose([
        lambda x: x.astype(np.float32),
        T.Normalize([125.3, 123.0, 113.9], [63.0, 62.1, 66.7]),
        lambda x: x.transpose(2, 0, 1).astype(np.float32),
        torch.from_numpy,
    ])

    train_transform = tnt.transform.compose([
        T.RandomHorizontalFlip(),
        T.Pad(opt.randomcrop_pad, cv2.BORDER_REFLECT),
        T.RandomCrop(32),
        convert,
    ])

    ds = getattr(datasets, opt.dataset)(
        opt.data_root, train=mode, download=True)
    smode = 'train' if mode else 'test'
    ds = tnt.dataset.TensorDataset([
        getattr(ds, smode + '_data'),
        getattr(ds, smode + '_labels')])
    return ds.transform({0: train_transform if mode else convert})
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号