def read_data(self, train_split=0.80, dev_split=0.10, test_split=0.10):
    """Load the PNG images in ``self.image_dir`` and split them into train/dev/test.

    Every ``*.png`` file in ``self.image_dir`` is read with
    ``imread(..., flatten=True)``, flattened to a vector of
    ``IMAGE_WIDTH * IMAGE_HEIGHT`` pixels and cast to ``np.uint8``. Pixel
    values are deliberately not rescaled. Files are processed in sorted path
    order, so the split is reproducible. Only the training split is shuffled
    afterwards.

    Args:
        train_split: Fraction of images assigned to ``self.train``.
        dev_split: Fraction of images assigned to ``self.dev``.
        test_split: Fraction reserved for ``self.test``. In practice
            ``self.test`` receives every image left after the train and dev
            slices, including rounding remainders.

    Sets:
        self.train, self.dev, self.test: 2-D ``uint8`` arrays with one
            flattened image per row. They have 0 rows when no images exist.
        self.data_dict: ``{'train': ..., 'dev': ..., 'test': ...}`` mapping to
            those same arrays.

    Raises:
        ValueError: If any split lies outside ``[0, 1]``, or if the splits do
            not sum to 1 (compared with a floating-point tolerance).
    """
    import math
    import os

    splits = (train_split, dev_split, test_split)
    # Explicit exceptions instead of `assert`: asserts vanish under `python -O`.
    if any(not 0.0 <= fraction <= 1.0 for fraction in splits):
        raise ValueError(f"split fractions must each lie in [0, 1], got {splits}")
    # Tolerant comparison: e.g. 0.7 + 0.2 + 0.1 == 0.9999999999999999 in floats.
    if not math.isclose(sum(splits), 1.0):
        raise ValueError(f"train/dev/test splits must sum to 1.0, got {sum(splits)}")

    # os.path.join works whether or not image_dir ends with a separator.
    # Sorting makes the split independent of filesystem listing order.
    all_images = sorted(glob.glob(os.path.join(self.image_dir, "*.png")))
    pixels_per_image = IMAGE_WIDTH * IMAGE_HEIGHT
    data = []
    for image_path in all_images:
        # NOTE(review): `flatten=True` matches the legacy `scipy.misc.imread`
        # signature (removed in SciPy 1.2). Confirm which imread is in scope.
        image = imread(image_path, flatten=True)
        # Intentionally no scaling to [0, 1] here.
        data.append(image.reshape(pixels_per_image))
    # The reshape keeps the array 2-D (0 rows) when the directory has no PNGs.
    data = np.array(data).astype(np.uint8).reshape(-1, pixels_per_image)

    total_images = data.shape[0]
    train_limit = int(total_images * train_split)
    dev_limit = train_limit + int(total_images * dev_split)
    self.train = data[:train_limit]
    self.dev = data[train_limit:dev_limit]
    self.test = data[dev_limit:]

    # Only the training data is shuffled; dev/test keep sorted file order.
    # np.random.shuffle permutes whole rows in place. The stdlib
    # random.shuffle swaps rows of a 2-D array through views, which can
    # duplicate some rows and drop others. Seed via np.random for
    # reproducible shuffles.
    np.random.shuffle(self.train)

    self.data_dict = {
        'train': self.train,
        'dev': self.dev,
        'test': self.test,
    }
# Web-page residue (not code): "评论列表" = "comment list",
# "文章目录" = "table of contents".