def retrieve_data_set():
    """Load every ``data_set/*.npz`` archive and stack it into one data set.

    Each archive must provide an ``"images"`` and a ``"labels"`` array. The
    arrays from all archives are stacked row-wise, in sorted file-name order,
    to build the data set used for MLP training.

    Returns:
        A tuple ``(X, Y)`` of ``float32`` arrays: the stacked images and the
        stacked labels.

    Raises:
        SystemExit: if no ``.npz`` file is found in ``data_set/``.
        KeyError: if an archive lacks the ``"images"`` or ``"labels"`` entry.
        ValueError: if the archives disagree on the number of columns.
    """
    start_time = cv2.getTickCount()
    print("Loading data set...")

    # Sort the matches: glob returns them in arbitrary, filesystem-dependent
    # order, which would make the row order differ between runs/machines.
    data_set = sorted(glob.glob("data_set/*.npz"))
    if not data_set:
        print("No data set in directory, exiting!")
        # Non-zero status so calling scripts can tell the load failed.
        sys.exit(1)

    # Collect the per-file chunks and stack once at the end. Stacking inside
    # the loop re-copies everything accumulated so far for every file, and
    # needed a dummy zero row whose width hard-coded the image/label size
    # (38400 and 4 columns); the width is now taken from the data itself.
    image_chunks = []
    label_chunks = []
    for single_npz in data_set:
        with np.load(single_npz) as data:
            # Convert while the archive is open; converting per chunk avoids
            # holding a float64 copy of the whole data set.
            image_chunks.append(np.asarray(data["images"], dtype=np.float32))
            label_chunks.append(np.asarray(data["labels"], dtype=np.float32))

    X = np.vstack(image_chunks)
    Y = np.vstack(label_chunks)
    print("Image array shape: {0}".format(X.shape))
    print("Label array shape: {0}".format(Y.shape))

    end_time = cv2.getTickCount()
    # Floor division: the duration is reported in whole seconds.
    print("Data set load duration: {0}"
          .format((end_time - start_time) // cv2.getTickFrequency()))
    return X, Y
评论列表
文章目录