data_loading_module.py source code

Project: road-segmentation  Author: paramoecium
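import os

import matplotlib.image as mpimg
import numpy as np

# `const` (IMG_PATCH_SIZE, IMG_PATCH_STRIDE), `pem` (label_img_crop) and
# `value_to_class` are defined elsewhere in the road-segmentation project (not shown here).


def extract_labels(filename_base, num_images, num_of_transformations=6, patch_size=const.IMG_PATCH_SIZE,
                   patch_stride=const.IMG_PATCH_STRIDE):
    """Extract the ground-truth labels as a 1-hot matrix [patch index, label index].

    Reads satImage_001.png ... satImage_<num_images>.png from `filename_base`,
    crops each mask into patches and assigns every patch a class from its mean value.
    """
    gt_imgs = []
    for i in range(1, num_images + 1):
        imageid = "satImage_%.3d" % i
        image_filename = filename_base + imageid + ".png"
        if os.path.isfile(image_filename):
            print('Loading ' + image_filename)
            gt_imgs.append(mpimg.imread(image_filename))
        else:
            print('File ' + image_filename + ' does not exist')

    print('Extracting patches...')
    # Crop every ground-truth mask into (possibly transformed) patches.
    gt_patches = [pem.label_img_crop(gt_img, patch_size, patch_stride, num_of_transformations)
                  for gt_img in gt_imgs]
    # Flatten the per-image patch lists into a single array of patches.
    data = np.asarray([patch for patches in gt_patches for patch in patches])
    # Label each patch from its mean pixel value.
    labels = np.asarray([value_to_class(np.mean(patch)) for patch in data])
    print(str(len(data)) + ' label patches extracted.')

    # value_to_class is expected to return the 1-hot class vector; here it is only cast to float32.
    return labels.astype(np.float32)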
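The function depends on project helpers that are not shown (`pem.label_img_crop`, `value_to_class`). As a rough, self-contained illustration of the underlying idea, the sketch below crops a 2-D mask into strided square patches and labels each patch by thresholding its mean value. `crop_patches`, `mean_to_one_hot`, `patch_labels` and the 0.25 threshold are all hypothetical; the project's real helpers may differ, and this sketch omits the `num_of_transformations` augmentation entirely.

```python
import numpy as np


def crop_patches(mask, patch_size, stride):
    """Cut a 2-D mask into patch_size x patch_size windows, stepping by `stride`.

    Hypothetical stand-in for pem.label_img_crop (no transformations applied).
    """
    height, width = mask.shape[:2]
    patches = []
    for top in range(0, height - patch_size + 1, stride):
        for left in range(0, width - patch_size + 1, stride):
            patches.append(mask[top:top + patch_size, left:left + patch_size])
    return patches


def mean_to_one_hot(mean_value, threshold=0.25):
    """Hypothetical value_to_class: [0, 1] = road, [1, 0] = background.

    The 0.25 threshold is an assumption for illustration.
    """
    return [0, 1] if mean_value > threshold else [1, 0]


def patch_labels(mask, patch_size, stride, threshold=0.25):
    """Return a float32 matrix with one 1-hot row per patch."""
    patches = crop_patches(mask, patch_size, stride)
    labels = np.asarray([mean_to_one_hot(float(np.mean(p)), threshold) for p in patches])
    return labels.astype(np.float32)


# Usage: a 4x4 mask whose left half is road (1.0) and right half background (0.0).
mask = np.zeros((4, 4), dtype=np.float32)
mask[:, :2] = 1.0
# Patches are visited row by row: road, background, road, background.
print(patch_labels(mask, patch_size=2, stride=2))
```

Thresholding the patch mean, rather than taking a majority vote per pixel, keeps the label computation to a single reduction per patch; the threshold then controls how much road a patch must contain to count as road.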