data_serialization.py source code

python

Project: logodetect  Author: munibasad
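import os

import numpy as np
from scipy import ndimage  # note: ndimage.imread was removed in SciPy 1.2

# CNN_IN_HEIGHT, CNN_IN_WIDTH, CNN_IN_CH and PIXEL_DEPTH are not defined in
# this snippet; they are assumed to be constants defined elsewhere in the project.


def load_logo(data_dir):
    """Load every image in data_dir into one normalized float32 array."""
    image_files = os.listdir(data_dir)
    # Preallocate room for every file; unreadable files are trimmed off below.
    dataset = np.ndarray(
        shape=(len(image_files), CNN_IN_HEIGHT, CNN_IN_WIDTH, CNN_IN_CH),
        dtype=np.float32)
    print(data_dir)
    num_images = 0
    for image in image_files:
        image_file = os.path.join(data_dir, image)
        try:
            # Center pixel values on zero and scale them by PIXEL_DEPTH.
            image_data = (ndimage.imread(image_file).astype(float) -
                          PIXEL_DEPTH / 2) / PIXEL_DEPTH
            # A shape mismatch is not caught below, so it aborts the load.
            if image_data.shape != (CNN_IN_HEIGHT, CNN_IN_WIDTH, CNN_IN_CH):
                raise Exception('Unexpected image shape: %s' %
                                str(image_data.shape))
            dataset[num_images, :, :, :] = image_data
            num_images += 1
        except IOError as e:
            print('Could not read:', image_file, ':', e,
                  "- it's ok, skipping.")

    # Keep only the rows that were actually filled.
    dataset = dataset[0:num_images, :, :, :]
    print('Full dataset tensor:', dataset.shape)
    print('Mean:', np.mean(dataset))
    print('Standard deviation:', np.std(dataset))
    return dataset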
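
The normalization step inside load_logo can be tried on its own. The sketch below is only an illustration: the value PIXEL_DEPTH = 255.0 and the helper name normalize are assumptions that do not appear in the original file. With that value, 8-bit pixels are mapped from [0, 255] to [-0.5, 0.5].

```python
import numpy as np

# Assumption for illustration: 8-bit images, so the maximum pixel value is 255.
PIXEL_DEPTH = 255.0


def normalize(pixels):
    """Center pixel values on zero and scale them by PIXEL_DEPTH (hypothetical helper)."""
    return (pixels.astype(float) - PIXEL_DEPTH / 2) / PIXEL_DEPTH


# The darkest and brightest 8-bit values land on the two ends of the range.
sample = np.array([0, 255], dtype=np.uint8)
normalized = normalize(sample)
print(normalized)
```

Centering the inputs on zero in this way is a common preprocessing choice for CNN inputs. The exact range depends on whatever value the project actually uses for PIXEL_DEPTH.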