def __init__(self, root, transform=None, target_transform=None,
             train=True, test=False, top=100, group=True,
             additional_transform=None):
    """Load one split of the processed region data and its lookup tables.

    If ``__check_exists`` reports that the processed files are missing,
    ``process_dataset`` is run first. The region file of the selected
    split, the corpus, the region-objects table and the object index are
    then read with ``torch.load`` and stored on the instance.

    Args:
        root: Path that must already exist on disk; stored as
            ``self.root``.
        transform: Stored as ``self.transform``; not applied here.
        target_transform: Stored as ``self.target_transform``; not
            applied here.
        train: When true, the training region file is loaded. Takes
            precedence over ``test``.
        test: Only consulted when ``train`` is false; selects the test
            region file. With both flags false the validation region
            file is loaded.
        top: Stored as ``self.top_objects`` and used to build the
            ``top_<top>`` folder name the split files are read from.
            Presumably the number of most frequent object classes kept
            -- TODO confirm against ``process_dataset``.
        group: When true, the loaded regions are passed through
            ``__group_regions_by_id`` before being stored.
        additional_transform: Stored as ``self.additional_transform``;
            not applied here.

    Raises:
        RuntimeError: If ``root`` does not exist.
    """
    self.root = root
    self.transform = transform
    self.additional_transform = additional_transform
    self.target_transform = target_transform
    self.top_objects = top
    self.top_folder = 'top_{0}'.format(top)
    self.group = group

    if not osp.exists(self.root):
        raise RuntimeError('Dataset not found, ' +
                           'please download it from: ' +
                           'http://visualgenome.org/api/v0/api_home.html')

    if not self.__check_exists():
        self.process_dataset()

    def _load(folder, filename):
        # Single place for the open/torch.load pair that every file below
        # shares. ``self.data_path`` is defined outside this method.
        # NOTE(security): torch.load unpickles its input, so these files
        # must come from a trusted source (normally this class's own
        # ``process_dataset`` output).
        with open(osp.join(self.data_path, folder, filename), 'rb') as f:
            return torch.load(f)

    # ``train`` wins over ``test``; validation is the fallback split.
    if train:
        split_file = self.region_train_file
    elif test:
        split_file = self.region_test_file
    else:
        split_file = self.region_val_file
    self.regions = _load(self.top_folder, split_file)

    if self.group:
        self.regions = self.__group_regions_by_id(self.regions)

    # The corpus lives in the processed folder, while the object tables
    # are specific to the chosen ``top_<top>`` folder.
    self.corpus = _load(self.processed_folder, self.corpus_file)
    self.region_objects = _load(self.top_folder, self.region_objects_file)
    self.obj_idx = _load(self.top_folder, self.obj_idx_file)

    # Inverse of ``obj_idx``: maps each value back to its key.
    self.idx_obj = {v: k for k, v in self.obj_idx.items()}
评论列表
文章目录