def get_augmented_test_set(data_root, idx_file,
                           scale_size, crop_size, aug_type='ten_crop',
                           seg_root=None, mixture=False):
    """Build one test dataset per test-time-augmentation view.

    For ``aug_type='ten_crop'`` this returns 10 ``MyImageFolder`` datasets:
    5 crop positions x 2 flip settings. They are ordered crop-major: both
    flip settings of crop 0, then both of crop 1, and so on.

    Each view's image transform is:
    ``Scale(scale_size)`` -> ``ToTensor`` -> ``RandomFlip(flip)`` ->
    ``SpecialCrop((crop_size, crop_size), crop_type)`` -> ``Normalize``.
    ``Normalize`` uses the standard torchvision ImageNet mean/std.

    Args:
        data_root: passed to ``MyImageFolder`` as ``root``.
        idx_file: passed to ``MyImageFolder`` as ``idx_file``.
        scale_size: size given to ``transforms.Scale`` before cropping.
        crop_size: side length of the square ``SpecialCrop``. It is also the
            ``Scale`` size used for the segmentation image when ``mixture``
            is true.
        aug_type: augmentation scheme; only ``'ten_crop'`` is supported.
        seg_root: passed to ``MyImageFolder`` as ``seg_root`` when
            ``mixture`` is true; ignored otherwise.
        mixture: if true, each dataset also receives a ``seg_transform``
            and ``seg_root``.

    Returns:
        list of ``MyImageFolder`` datasets, one per augmented view.

    Raises:
        ValueError: if ``aug_type`` is not ``'ten_crop'``.
    """
    if aug_type != 'ten_crop':
        # Fail loudly: an empty list of test views is easy to mistake for
        # "no test data" further down the pipeline.
        raise ValueError(
            "unsupported aug_type {!r}; only 'ten_crop' is implemented"
            .format(aug_type))

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # Crop positions for SpecialCrop, per the original author's notes:
    # 0: center, 1: top left, 2: top right, 3: bottom right, 4: bottom left.
    crop_types = [0, 1, 2, 3, 4]
    # Flip settings for RandomFlip: 0: no flip, 1: horizontal flip.
    flips = [0, 1]

    dsets = []
    for crop_type in crop_types:
        for flip in flips:
            data_transform = transforms.Compose([
                transforms.Scale(scale_size),
                transforms.ToTensor(),
                RandomFlip(flip),
                SpecialCrop((crop_size, crop_size), crop_type=crop_type),
                transforms.Normalize(mean, std),
            ])
            if mixture:
                # The segmentation image is resized to crop_size and flipped
                # the same way as the main image, but it is not cropped.
                seg_transform = transforms.Compose([
                    transforms.Scale(crop_size),
                    transforms.ToTensor(),
                    RandomFlip(flip),
                    transforms.Normalize(mean, std),
                ])
                dset = MyImageFolder(root=data_root,
                                     idx_file=idx_file,
                                     transform=data_transform,
                                     seg_transform=seg_transform,
                                     seg_root=seg_root)
            else:
                dset = MyImageFolder(root=data_root,
                                     idx_file=idx_file,
                                     transform=data_transform)
            dsets.append(dset)
    return dsets
# Comment list
# Table of contents