input_data.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:IntroToDeepLearning 作者: robb-brown 项目源码 文件源码
def extract_labels(filename, one_hot=False):
  """Extract the labels into a 1D uint8 numpy array [index]."""
  print('Extracting', filename)
  with gzip.open(filename) as bytestream:
    magic = _read32(bytestream)
    if magic != 2049:
      raise ValueError(
          'Invalid magic number %d in MNIST label file: %s' %
          (magic, filename))
    num_items = _read32(bytestream)
    buf = bytestream.read(num_items)
    labels = numpy.frombuffer(buf, dtype=numpy.uint8)
    if one_hot:
      return dense_to_one_hot(labels)
    return labels
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号