def get_mnist_labels(filename, num_samples, local_data_dir):
    """Load MNIST labels from a gzipped IDX1 label file.

    The file path is obtained from ``load_or_download_mnist_files``; the
    header (magic number and entry count) is validated, then exactly
    ``num_samples`` one-byte labels are read.

    Args:
        filename: Label file name, passed to ``load_or_download_mnist_files``.
        num_samples: Number of labels the file is expected to contain.
        local_data_dir: Directory passed to ``load_or_download_mnist_files``.

    Returns:
        A writable ``np.uint8`` array of shape ``(num_samples, 1)``.

    Raises:
        Exception: If the magic number is wrong, the entry count differs
            from ``num_samples``, or the file ends before all labels are read.
    """
    gzfname = load_or_download_mnist_files(filename, num_samples, local_data_dir)
    with gzip.open(gzfname) as gz:
        # Read magic number. IDX integers are stored big-endian (MSB first);
        # reading with an explicit '>I' makes the check host-independent.
        # 0x00000801 is the IDX1 unsigned-byte label magic (previously
        # compared as the little-endian-swapped 0x1080000).
        magic_bytes = gz.read(4)
        if len(magic_bytes) < 4 or struct.unpack('>I', magic_bytes)[0] != 0x00000801:
            raise Exception('Invalid file: unexpected magic number.')
        # Read number of entries (big-endian, like every IDX header field).
        count_bytes = gz.read(4)
        if len(count_bytes) < 4 or struct.unpack('>I', count_bytes)[0] != num_samples:
            raise Exception('Invalid file: expected {0} rows.'.format(num_samples))
        # Read labels: one unsigned byte per sample.
        raw = gz.read(num_samples)
        if len(raw) != num_samples:
            # Truncated payload: report it instead of failing later in reshape.
            raise Exception('Invalid file: expected {0} rows.'.format(num_samples))
    # frombuffer replaces the deprecated binary-mode np.fromstring; copy() keeps
    # the result writable, as fromstring's fresh array was.
    res = np.frombuffer(raw, dtype=np.uint8).copy()
    return res.reshape((num_samples, 1))
评论列表
文章目录