loader.py source code

Project: neural_network_habr_guide   Author: m9psy
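import struct
import sys
from typing import Tuple

import numpy as np

# `downloader` is not shown in this excerpt; it is assumed to be defined
# elsewhere in loader.py and to return the decompressed file contents as bytes.


def load_data(images_url: str, labels_url: str) -> Tuple[np.ndarray, np.ndarray]:
    images_decompressed = downloader(images_url)

    # Big endian: 4 unsigned ints, 4 bytes each (magic, image count, rows, columns)
    magic, size, rows, cols = struct.unpack(">IIII", images_decompressed[:16])
    if magic != 2051:
        print("Wrong magic for", images_url, "- the file is probably corrupted", file=sys.stderr)
        sys.exit(2)

    # One row per image with rows * cols pixels, scaled from 0..255 to 0.0..1.0
    image_data = (np.frombuffer(images_decompressed[16:], dtype=np.ubyte)
                  .reshape(-1, rows * cols)
                  .astype(np.float32) / 255)

    labels_decompressed = downloader(labels_url)
    # Big endian: 2 unsigned ints, 4 bytes each (magic, label count)
    magic, size = struct.unpack(">II", labels_decompressed[:8])
    if magic != 2049:
        print("Wrong magic for", labels_url, "- the file is probably corrupted", file=sys.stderr)
        sys.exit(2)

    # The rest of the file is one unsigned byte per label
    labels = np.frombuffer(labels_decompressed[8:], dtype=np.ubyte)

    return image_data, labels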
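The header checks above follow the IDX format used by the MNIST files: a big-endian magic number (2051 = 0x00000803 for unsigned-byte data with 3 dimensions, 2049 = 0x00000801 for unsigned-byte data with 1 dimension), then one 4-byte size per dimension, then the raw bytes. The sketch below parses IDX data that is already in memory, using only the standard library, so the format can be tried without downloading anything. `parse_idx_images` and `parse_idx_labels` are hypothetical helper names, not part of the project; unlike the original they raise `ValueError` instead of calling `exit`, and they return plain lists rather than NumPy arrays.

```python
import struct
from typing import List, Tuple

IMAGES_MAGIC = 2051  # 0x00000803: unsigned byte data, 3 dimensions
LABELS_MAGIC = 2049  # 0x00000801: unsigned byte data, 1 dimension


def parse_idx_images(data: bytes) -> Tuple[List[List[float]], int, int]:
    """Parse an in-memory IDX image file (hypothetical helper).

    Returns (images, rows, cols); each image is a flat list of
    rows * cols pixel values scaled from 0..255 into 0.0..1.0.
    """
    # Header: four big-endian unsigned 32-bit ints, 16 bytes in total.
    magic, count, rows, cols = struct.unpack(">IIII", data[:16])
    if magic != IMAGES_MAGIC:
        raise ValueError(f"bad image magic {magic}, file probably corrupted")
    pixels = data[16:]
    n = rows * cols
    if len(pixels) != count * n:
        raise ValueError("pixel data length does not match header")
    # Iterating over bytes yields ints 0..255; divide to normalize.
    images = [[b / 255 for b in pixels[i * n:(i + 1) * n]] for i in range(count)]
    return images, rows, cols


def parse_idx_labels(data: bytes) -> List[int]:
    """Parse an in-memory IDX label file: 8-byte header, one byte per label."""
    magic, count = struct.unpack(">II", data[:8])
    if magic != LABELS_MAGIC:
        raise ValueError(f"bad label magic {magic}, file probably corrupted")
    labels = list(data[8:])
    if len(labels) != count:
        raise ValueError("label count does not match header")
    return labels


# Usage: a tiny synthetic pair of files holding 2 images of 2x3 pixels.
raw_images = struct.pack(">IIII", IMAGES_MAGIC, 2, 2, 3) + bytes(
    [0, 255, 51, 102, 153, 204, 255, 0, 0, 0, 0, 255]
)
raw_labels = struct.pack(">II", LABELS_MAGIC, 2) + bytes([7, 3])

images, rows, cols = parse_idx_images(raw_images)
labels = parse_idx_labels(raw_labels)
print(rows, cols, len(images), labels)  # 2 3 2 [7, 3]
print(images[0])  # [0.0, 1.0, 0.2, 0.4, 0.6, 0.8]
```

Raising an exception leaves the decision to the caller (for example, retrying the download), whereas `exit(2)` stops the whole interpreter from inside a library function.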