Run.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:svhn-digit-classification 作者: yanji84 项目源码 文件源码
def load():
  """Load the labelled images listed in "digit.out" and split them three ways.

  Each non-blank line of "digit.out" is split on commas: the first field is
  appended to the module-level ``imagePath`` to form the image file name, and
  the second field is parsed as the integer class label.  Every image is
  read with OpenCV and resized to ``imageSize`` x ``imageSize``.  When
  ``whiteWash`` is truthy the image is converted to float32 and its own mean
  is subtracted (mean-centering only; no variance scaling is applied).

  When ``debug`` is truthy, reading stops once more than 1000 samples have
  been collected.

  Labels are one-hot encoded into a ``(num_samples, 10)`` array.  Samples are
  then split in file order: the first ``dataRatio[0]`` fraction is the
  training set, the next ``dataRatio[1]`` fraction is the test set, and the
  remainder is the validation set.

  Results are published through module globals rather than returned:
  ``onehotLabels``, ``allImagesTrain``/``allLabelsTrain``,
  ``allImagesTest``/``allLabelsTest`` and
  ``allImagesValidate``/``allLabelsValidate``.

  Raises:
    IOError: if an image listed in "digit.out" cannot be read.
    ValueError: if a label field is not an integer.
  """
  global onehotLabels
  global allImagesTrain
  global allLabelsTrain
  global allImagesTest
  global allLabelsTest
  global allImagesValidate
  global allLabelsValidate
  allImages = []
  allLabels = []
  with open("digit.out") as f:
    for line in f:
      # Tolerate blank lines (e.g. a trailing newline at end of file).
      if not line.strip():
        continue
      parts = line.split(",")
      fileName = imagePath + parts[0]
      print(fileName)

      # cv2.imread signals failure by returning None instead of raising;
      # fail here with the offending path rather than inside cv2.resize.
      rawImg = cv2.imread(fileName, cv2.IMREAD_UNCHANGED)
      if rawImg is None:
        raise IOError("could not read image: %s" % fileName)
      img = cv2.resize(rawImg, (imageSize, imageSize))

      # white wash image: subtract the per-image mean
      if whiteWash:
        imgMean = np.mean(img)
        img = img.astype(np.float32)
        img -= imgMean

      allImages.append(img)
      # Labels must be integers to be usable as one-hot column indices;
      # int() also discards the trailing newline left by split(",").
      allLabels.append(int(parts[1]))
      if debug and len(allLabels) > 1000:
        break

  numSamples = len(allLabels)
  onehotLabels = np.zeros((numSamples, 10))
  # Explicit integer dtype keeps the fancy index valid even when empty.
  labelIdx = np.asarray(allLabels, dtype=int)
  onehotLabels[np.arange(numSamples), labelIdx] = 1

  # Contiguous split in file order: train | test | validate.
  trainIdx = int(numSamples * dataRatio[0])
  testIdx = int(trainIdx + numSamples * dataRatio[1])
  allImagesTrain = allImages[:trainIdx]
  allLabelsTrain = onehotLabels[:trainIdx]
  allImagesTest = allImages[trainIdx:testIdx]
  allLabelsTest = onehotLabels[trainIdx:testIdx]
  allImagesValidate = allImages[testIdx:]
  allLabelsValidate = onehotLabels[testIdx:]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号