def load(labelFile="digit.out", numClasses=10):
    """Load digit images and labels, one-hot encode them, and split the data.

    Each non-blank line of ``labelFile`` must look like
    ``<image file name>,<integer class label>``. The image file name is
    appended to the module-level ``imagePath`` prefix and the image is read
    with OpenCV (``cv2.IMREAD_UNCHANGED``). It is then resized to
    ``imageSize`` x ``imageSize``. When the module-level ``whiteWash`` flag
    is set, the image is converted to float32 and its mean pixel value is
    subtracted. When ``debug`` is set, loading stops once more than 1000
    samples have been collected, i.e. after 1001 samples.

    Labels are one-hot encoded into a ``(num_samples, numClasses)`` float
    array. The samples are split in file order using the module-level
    ``dataRatio``: ``dataRatio[0]`` is the training fraction and
    ``dataRatio[1]`` is the test fraction. Whatever remains is the
    validation set.

    Results are published through module-level globals rather than
    returned: ``onehotLabels``, ``allImagesTrain`` / ``allLabelsTrain``,
    ``allImagesTest`` / ``allLabelsTest`` and ``allImagesValidate`` /
    ``allLabelsValidate``. The image sets are Python lists; the label sets
    are row slices of ``onehotLabels``.

    Args:
        labelFile: path of the comma-separated label file. Defaults to
            ``"digit.out"``, the file this function always read before.
        numClasses: number of classes, i.e. the width of the one-hot
            vectors. Defaults to 10.

    Raises:
        ValueError: if a non-blank line has no label field, or the label is
            not an integer in ``[0, numClasses)``.
        IOError: if OpenCV cannot read an image file.
    """
    global onehotLabels
    global allImagesTrain
    global allLabelsTrain
    global allImagesTest
    global allLabelsTest
    global allImagesValidate
    global allLabelsValidate
    allImages = []
    allLabels = []
    # Read the whole label file up front so it is closed before the
    # (potentially slow) image loading starts.
    with open(labelFile) as f:
        lines = f.readlines()
    for lineNo, line in enumerate(lines, 1):
        if not line.strip():
            # Skip blank lines (e.g. a trailing empty line at end of file).
            continue
        parts = line.split(",")
        if len(parts) < 2:
            raise ValueError("{}:{}: expected '<file>,<label>', got {!r}".format(
                labelFile, lineNo, line))
        fileName = imagePath + parts[0]
        print(fileName)
        # parts[1] still carries the line's trailing newline; int() ignores
        # surrounding whitespace. Labels must be ints because they are used
        # as column indices into the one-hot matrix below.
        try:
            label = int(parts[1])
        except ValueError:
            raise ValueError("{}:{}: label is not an integer: {!r}".format(
                labelFile, lineNo, parts[1]))
        if not 0 <= label < numClasses:
            # Negative labels would otherwise wrap around silently in numpy.
            raise ValueError("{}:{}: label {} outside [0, {})".format(
                labelFile, lineNo, label, numClasses))
        raw = cv2.imread(fileName, cv2.IMREAD_UNCHANGED)
        if raw is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise IOError("could not read image: {}".format(fileName))
        img = cv2.resize(raw, (imageSize, imageSize))
        if whiteWash:
            # Center the image by subtracting its mean pixel value. Dividing
            # by the per-image standard deviation was left disabled
            # (commented out) in the original code.
            imgMean = np.mean(img)
            img = img.astype(np.float32)
            img -= imgMean
        allImages.append(img)
        allLabels.append(label)
        if debug and len(allLabels) > 1000:
            break
    numSamples = len(allLabels)
    onehotLabels = np.zeros((numSamples, numClasses))
    # An explicit integer index array also keeps the empty case valid
    # (np.asarray([]) alone would be float and rejected as an index).
    onehotLabels[np.arange(numSamples), np.asarray(allLabels, dtype=np.intp)] = 1
    trainIdx = int(numSamples * dataRatio[0])
    testIdx = int(trainIdx + numSamples * dataRatio[1])
    allImagesTrain = allImages[:trainIdx]
    allLabelsTrain = onehotLabels[:trainIdx]
    allImagesTest = allImages[trainIdx:testIdx]
    allLabelsTest = onehotLabels[trainIdx:testIdx]
    allImagesValidate = allImages[testIdx:]
    allLabelsValidate = onehotLabels[testIdx:]
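# NOTE(review): web-page navigation labels captured when this code was copied
# ("Comment list", "Table of contents"); they are not part of the program.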