def imagenet_transform(scale_size=256, input_size=224, train=True, allow_var_size=False):
    """Build a torchvision preprocessing pipeline for image inputs.

    Every pipeline starts with a resize to ``scale_size`` and ends with
    ``ToTensor`` followed by ``Normalize`` using a fixed per-channel
    mean/std. What happens in between depends on the mode:

    * ``train=True``: random crop to ``input_size`` plus a random
      horizontal flip. ``allow_var_size`` is ignored in this mode.
    * ``train=False, allow_var_size=True``: no cropping at all, so the
      output keeps whatever size the resize step produced.
    * ``train=False, allow_var_size=False``: center crop to ``input_size``.

    Args:
        scale_size: Size argument forwarded to the resize transform.
        input_size: Size argument forwarded to the crop transform.
        train: Select the training (augmenting) pipeline.
        allow_var_size: In evaluation mode, skip the center crop.

    Returns:
        A ``transforms.Compose`` wrapping the selected steps.
    """
    # ``Scale`` was renamed to ``Resize`` in torchvision and the old name was
    # later removed. Prefer the new name, but keep working on old installs
    # that only provide ``Scale``.
    resize_cls = getattr(transforms, 'Resize', None)
    if resize_cls is None:
        resize_cls = transforms.Scale

    steps = [resize_cls(scale_size)]

    if train:
        steps.append(transforms.RandomCrop(input_size))
        steps.append(transforms.RandomHorizontalFlip())
    elif not allow_var_size:
        steps.append(transforms.CenterCrop(input_size))

    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225]))
    return transforms.Compose(steps)
评论列表
文章目录