def _abort_corrupted(*message):
    """Print *message* to stderr and terminate the process with exit status 2."""
    import sys

    print(*message, file=sys.stderr)
    # sys.exit is always available; the bare `exit` builtin is injected by
    # `site` and is missing under `python -S` or in frozen executables.
    sys.exit(2)


def _parse_idx_ubyte(blob, url, expected_magic, n_dims):
    """Validate an unsigned-byte IDX blob and return ``(dims, payload)``.

    The header is a big-endian unsigned 32-bit magic number followed by
    *n_dims* big-endian unsigned 32-bit dimension sizes (4 bytes each).
    *payload* is a flat read-only ``uint8`` view over the bytes after the
    header. Exits with status 2 if the header is short, the magic number is
    wrong, or the payload length disagrees with the product of *dims*.
    """
    header_len = 4 * (1 + n_dims)
    if len(blob) < header_len:
        _abort_corrupted("Truncated header in", url, "Probably file corrupted")
    magic, *dims = struct.unpack(">" + "I" * (1 + n_dims), blob[:header_len])
    if magic != expected_magic:
        _abort_corrupted("Wrong magic for", url, "Probably file corrupted")
    # offset= avoids copying the (possibly large) payload out of the blob.
    payload = np.frombuffer(blob, dtype=np.uint8, offset=header_len)
    expected = 1
    for dim in dims:
        expected *= dim
    if payload.size != expected:
        _abort_corrupted(
            "Unexpected payload size in", url, "Probably file corrupted"
        )
    return dims, payload


def load_data(
    images_url: str, labels_url: str, *, fetch=None
) -> "tuple[np.ndarray, np.ndarray]":
    """Fetch and decode an IDX (MNIST-style) image file and its label file.

    Args:
        images_url: Location of the image file (IDX magic 2051: count, rows,
            cols header followed by ``count * rows * cols`` unsigned bytes).
        labels_url: Location of the label file (IDX magic 2049: count header
            followed by ``count`` unsigned bytes).
        fetch: Callable mapping a URL to the file's decompressed bytes.
            Defaults to this module's ``downloader``.

    Returns:
        ``(images, labels)``. ``images`` is a float32 array of shape
        ``(count, rows * cols)`` with pixel values scaled to ``[0, 1]``.
        ``labels`` is a writable uint8 array of shape ``(count,)``.

    Exits the process with status 2 (message on stderr) if either file has a
    wrong magic number, a truncated header, a payload whose length disagrees
    with its header, or if the image and label counts differ.
    """
    if fetch is None:
        fetch = downloader

    (size, rows, cols), pixels = _parse_idx_ubyte(
        fetch(images_url), images_url, 2051, 3
    )
    # Convert to float32 first and divide in place: this avoids a float64
    # temporary twice the size of the result and yields identical values.
    image_data = pixels.reshape(size, rows * cols).astype(np.float32)
    image_data /= 255

    _, raw_labels = _parse_idx_ubyte(fetch(labels_url), labels_url, 2049, 1)
    # frombuffer views are read-only; copy so callers can shuffle in place.
    labels = raw_labels.copy()

    if labels.shape[0] != image_data.shape[0]:
        _abort_corrupted(
            "Image/label count mismatch for", images_url, "and", labels_url
        )
    return image_data, labels
评论列表
文章目录