def _extract_labels(filename, num_classes=10):
"""???????????????
:param filename: ?????
:param num_classes: ??one-hot??????????10?
:return: 2??numpy??[index, num_classes]? ???np.float32
"""
labels = []
print('Extracting {}'.format(filename))
with gzip.GzipFile(fileobj=open(filename, 'rb')) as f:
buf = f.read()
index = 0
magic, num_labels = struct.unpack_from('>II', buf, index)
if magic != 2049:
raise ValueError('Invalid magic number {} in MNIST label file: {}'.format(magic, filename))
index += struct.calcsize('>II')
for i in range(num_labels):
label = struct.unpack_from('>B', buf, index)
index += struct.calcsize('>B')
label_one_hot = np.zeros(num_classes, dtype=np.float32)
label_one_hot[label[0]] = 1
labels.append(label_one_hot)
return np.array(labels, dtype=np.float32)
评论列表
文章目录