train.py 文件源码

python
阅读 81 收藏 0 点赞 0 评论 0

项目:kaggle-dstl 作者: lopuhin 项目源码 文件源码
def load_image(self, im_id: str) -> Image:
        """Load the image data and class masks for ``im_id``, using a disk cache.

        Preprocessed image data is cached in ``im_cache/<im_id>.data`` and
        the mask array in ``im_cache/<im_id>.mask`` (both written with
        ``np.save``). The mask cache is neither read nor written when
        ``self.hps.pre_buffer`` is truthy; in that case the polygons stored
        under key 2 are buffered by ``pre_buffer`` and the mask is rebuilt
        on every call.

        Returns an ``Image`` built from ``im_id``, the image data (cut down
        to the first ``self.hps.n_channels`` entries along axis 0 when the
        counts differ) and the mask indexed by ``self.hps.classes``.
        """
        logger.info('Loading {}'.format(im_id))
        cache_dir = Path('im_cache')
        cache_dir.mkdir(exist_ok=True)
        data_file = cache_dir / '{}.data'.format(im_id)
        mask_file = cache_dir / '{}.mask'.format(im_id)

        # Image data: compute and store on a cache miss, otherwise read back.
        if not data_file.exists():
            im_data = self.preprocess_image(utils.load_image(im_id))
            with data_file.open('wb') as out:
                np.save(out, im_data)
        else:
            im_data = np.load(str(data_file))

        pre_buffer = self.hps.pre_buffer
        if mask_file.exists() and not pre_buffer:
            mask = np.load(str(mask_file))
        else:
            im_size = im_data.shape[1:]
            polygons = utils.load_polygons(im_id, im_size)
            if pre_buffer:
                structures = 2
                buffered = polygons[structures].buffer(pre_buffer)
                polygons[structures] = utils.to_multipolygon(buffered)
            # Polygon keys are 1-based while mask rows are 0-based.
            per_class_masks = [
                utils.mask_for_polygons(im_size, polygons[poly_key])
                for poly_key in range(1, self.hps.total_classes + 1)
            ]
            mask = np.array(per_class_masks, dtype=np.uint8)
            # A mask built from buffered polygons is never persisted.
            if not pre_buffer:
                with mask_file.open('wb') as out:
                    np.save(out, mask)

        if self.hps.n_channels != im_data.shape[0]:
            im_data = im_data[:self.hps.n_channels]
        return Image(im_id, im_data, mask[self.hps.classes])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号