def _read_datafile(self, path, expected_dims):
    """Read a gzip-compressed, unsigned-byte IDX file into a NumPy array.

    The decompressed stream is expected to contain a big-endian 32-bit
    magic number, then one big-endian 32-bit size per dimension, then
    ``prod(sizes)`` data bytes.

    Args:
        path: Filename of the gzip-compressed IDX file.
        expected_dims: Number of dimensions the file must declare; it is
            folded into the expected magic number.

    Returns:
        numpy.ndarray: ``uint8`` array shaped by the header's dimension
        sizes. It wraps the bytes read via ``np.frombuffer`` and is
        therefore read-only.

    Raises:
        ValueError: If the magic number is not ``2048 + expected_dims``,
            or if the stream ends before the header or the declared amount
            of data has been read.
    """
    # 2048 == 0x0800: in IDX the third header byte is the element type and
    # 0x08 means unsigned byte (matching the uint8 dtype below); the low
    # byte holds the number of dimensions.
    base_magic_num = 2048
    expected_magic_num = base_magic_num + expected_dims

    def read_exact(f, num_bytes, what):
        # A short read means the stream ended early; report it as a
        # malformed file rather than letting it surface later as a
        # struct.error or a confusing reshape error.
        chunk = f.read(num_bytes)
        if len(chunk) != num_bytes:
            raise ValueError('Truncated IDX file {!r}: expected {} bytes of '
                             '{}, got {}'.format(path, num_bytes, what,
                                                 len(chunk)))
        return chunk

    with gzip.GzipFile(path) as f:
        magic_num = struct.unpack('>I', read_exact(f, 4, 'magic number'))[0]
        if magic_num != expected_magic_num:
            raise ValueError('Incorrect MNIST magic number (expected '
                             '{}, got {})'
                             .format(expected_magic_num, magic_num))
        dims = struct.unpack('>' + 'I' * expected_dims,
                             read_exact(f, 4 * expected_dims, 'dimensions'))
        # The initializer makes a 0-dimensional file hold a single element
        # instead of raising TypeError on an empty sequence.
        num_items = reduce(operator.mul, dims, 1)
        buf = read_exact(f, num_items, 'data')
    data = np.frombuffer(buf, dtype=np.uint8)
    # Pass the shape as a tuple so an empty shape (0-d array) also works.
    return data.reshape(dims)
评论列表
文章目录