def cifar10():
    """Build the CIFAR-10 preprocessing configuration.

    Returns:
        A dict with the following keys:
            'train_transformation': ``data.TransformTwice`` wrapping a
                ``transforms.Compose`` of ``data.RandomTranslateWithReflect(4)``,
                ``RandomHorizontalFlip``, ``ToTensor`` and ``Normalize``.
                (``TransformTwice`` presumably applies the pipeline twice to
                produce two augmented views -- confirm in ``data``.)
            'eval_transformation': ``transforms.Compose`` of ``ToTensor`` and
                ``Normalize`` only (no random augmentation).
            'datadir': relative path to the image directory.
            'num_classes': 10.
    """
    # Per-channel statistics handed to transforms.Normalize.
    cifar_mean = [0.4914, 0.4822, 0.4465]
    cifar_std = [0.2470, 0.2435, 0.2616]

    def make_normalize():
        # Each pipeline gets its own Normalize instance, built from the same
        # mean/std lists.
        return transforms.Normalize(mean=cifar_mean, std=cifar_std)

    # Training pipeline: augmentation first, then tensor conversion and
    # normalization.
    augment_pipeline = transforms.Compose([
        data.RandomTranslateWithReflect(4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        make_normalize(),
    ])
    train_pipeline = data.TransformTwice(augment_pipeline)

    # Evaluation pipeline: deterministic conversion + normalization.
    eval_pipeline = transforms.Compose([
        transforms.ToTensor(),
        make_normalize(),
    ])

    return dict(
        train_transformation=train_pipeline,
        eval_transformation=eval_pipeline,
        datadir='data-local/images/cifar/cifar10/by-image',
        num_classes=10,
    )
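# Web-page residue from the source this code was copied from (not code):
# "评论列表" (Comment list), "文章目录" (Table of contents).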