visual_genome_loader.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:textobjdetection 作者: andfoy 项目源码 文件源码
def __init__(self, root, transform=None, target_transform=None,
                 train=True, test=False, top=100, group=True,
                 additional_transform=None):
        """Initialise the dataset from the serialized files below the data path.

        Args:
            root: Dataset root directory; it must already exist.
            transform: Stored on the instance as ``self.transform``.
            target_transform: Stored as ``self.target_transform``.
            train: When truthy, the training region file is loaded. This
                flag is checked first, so it takes precedence over ``test``.
            test: When truthy (and ``train`` is falsy), the test region file
                is loaded. When both flags are falsy the validation region
                file is loaded instead.
            top: Stored as ``self.top_objects`` and used to build the
                ``top_<top>`` folder name the per-split files are read from.
            group: When truthy, the loaded regions are passed through
                ``self.__group_regions_by_id``.
            additional_transform: Stored as ``self.additional_transform``.

        Raises:
            RuntimeError: If ``root`` does not exist on disk.
        """
        def _load(*path_parts):
            # Read one serialized artifact located under ``self.data_path``.
            # NOTE(review): torch.load unpickles its input, so these files
            # must come from a trusted source.
            with open(osp.join(self.data_path, *path_parts), 'rb') as handle:
                return torch.load(handle)

        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.additional_transform = additional_transform
        self.group = group
        self.top_objects = top
        self.top_folder = 'top_{0}'.format(top)

        if not osp.exists(self.root):
            raise RuntimeError('Dataset not found '
                               'please download it from: '
                               'http://visualgenome.org/api/v0/api_home.html')

        # Build the processed files on demand when the check reports that
        # they are not there yet.
        if not self.__check_exists():
            self.process_dataset()

        # Exactly one split is loaded: train first, then test, else val.
        if train:
            split_file = self.region_train_file
        elif test:
            split_file = self.region_test_file
        else:
            split_file = self.region_val_file

        self.regions = _load(self.top_folder, split_file)
        if self.group:
            self.regions = self.__group_regions_by_id(self.regions)

        self.corpus = _load(self.processed_folder, self.corpus_file)
        self.region_objects = _load(self.top_folder, self.region_objects_file)
        self.obj_idx = _load(self.top_folder, self.obj_idx_file)

        # Reverse lookup: maps each value of ``obj_idx`` back to its key.
        self.idx_obj = {value: key for key, value in self.obj_idx.items()}
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号