def get_mnist_data(is_train, image_size, batchsize):
    """Build the batched MNIST dataflow for training or evaluation.

    Args:
        is_train: When truthy, load the ``'train'`` split and apply random
            augmentation; otherwise load the ``'test'`` split and only resize.
        image_size: Side length (pixels) of the square images produced by the
            final resize step, in both the training and the test pipeline.
        batchsize: Number of datapoints per batch, forwarded to ``BatchData``.

    Returns:
        The dataflow obtained by wrapping ``MNISTCh`` with augmentation,
        batching and prefetching.
    """
    # NOTE(review): the test split is also created with shuffle=True; kept
    # as in the original -- confirm whether evaluation relies on this.
    ds = MNISTCh('train' if is_train else 'test', shuffle=True)
    target_shape = (image_size, image_size)

    if is_train:
        augs = [
            imgaug.RandomApplyAug(imgaug.RandomResize((0.8, 1.2), (0.8, 1.2)), 0.3),
            imgaug.RandomApplyAug(imgaug.RotationAndCropValid(15), 0.5),
            imgaug.RandomApplyAug(
                imgaug.SaltPepperNoise(white_prob=0.01, black_prob=0.01), 0.25),
            # Resize to the requested size rather than a hard-coded 224x224,
            # so training images match the size used by the test branch.
            imgaug.Resize(target_shape, cv2.INTER_AREA),
        ]
        ds = AugmentImageComponent(ds, augs)
        # Prefetch individual augmented datapoints with one worker per CPU,
        # then batch, then prefetch whole batches.
        ds = PrefetchData(ds, 128 * 10, multiprocessing.cpu_count())
        ds = BatchData(ds, batchsize)
        ds = PrefetchData(ds, 256, 4)
    else:
        # no augmentation, only resizing
        augs = [
            imgaug.Resize(target_shape, cv2.INTER_CUBIC),
        ]
        ds = AugmentImageComponent(ds, augs)
        ds = BatchData(ds, batchsize)
        ds = PrefetchData(ds, 20, 2)
    return ds
# (web page residue) Comment list
# (web page residue) Article table of contents