def _extract_labels(filename, num_labels):
    """Extract the labels into a vector of int64 label IDs.

    The first 8 bytes of the decompressed stream are skipped as a header;
    each of the following ``num_labels`` bytes is read as one unsigned
    8-bit label and widened to int64.

    Args:
      filename: The path to a gzip-compressed MNIST labels file.
      num_labels: The number of labels to read from the file.

    Returns:
      A numpy int64 array of shape [num_labels].

    Raises:
      ValueError: If ``num_labels`` is negative, or if the file ends before
        the header plus ``num_labels`` label bytes have been read.
    """
    header_size = 8  # presumably the IDX magic number + item count (4 bytes each)

    print('Extracting labels from: ', filename)

    if num_labels < 0:
        # read() treats a negative size as "read until EOF", which would
        # silently return every label in the file instead of failing.
        raise ValueError(
            'num_labels must be non-negative, got %r' % (num_labels,))

    with gzip.open(filename) as bytestream:
        header = bytestream.read(header_size)
        if len(header) < header_size:
            raise ValueError(
                '%s is too short to contain a %d-byte header'
                % (filename, header_size))

        # Each label occupies exactly one byte.
        buf = bytestream.read(num_labels)

    if len(buf) != num_labels:
        # A short read means the stream hit EOF early; returning a smaller
        # array would break the documented [num_labels] shape.
        raise ValueError(
            'Expected %d labels in %s but found only %d'
            % (num_labels, filename, len(buf)))

    return np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
评论列表
文章目录