def generate_data(train_path, test_path, input_dir=None, *,
                  images_per_class=10, test_per_class=2, size=(28, 28)):
    """Resize the ``.bmp`` images under a source folder and split them into
    numbered class folders for training and testing.

    Images are consumed in consecutive groups of ``images_per_class``. Each
    group becomes one class, and its 1-based number is used as the sub-folder
    name. Within a group:

    - The first ``test_per_class`` images are written to
      ``test_path/<class>/<position>.jpg``.
    - The remaining images go to ``train_path/<class>/<position>.jpg``.

    Every image is resized to ``size`` with ``cv2.INTER_AREA`` first. The
    grouping is purely by count, so the input is expected to contain exactly
    ``images_per_class`` ``.bmp`` files per class folder. That is an
    assumption of this function; it is not checked.

    Args:
        train_path: Root folder for the training images.
        test_path: Root folder for the test images.
        input_dir: Folder to walk for ``.bmp`` files. Defaults to the
            module-level ``input_path`` (the original behaviour).
        images_per_class: Number of consecutive images that form one class.
        test_per_class: How many images of each class go to the test set.
        size: ``(width, height)`` target passed to ``cv2.resize``.

    Raises:
        OSError: If an image cannot be read, or a resized image cannot be
            written.
    """
    source = input_path if input_dir is None else input_dir
    position = 1      # 1-based position of the next image inside its class
    class_index = 1   # 1-based class number, used as the output sub-folder
    for dirpath, dirnames, filenames in os.walk(source):
        # Visit sub-folders in a stable order so class numbering is reproducible.
        dirnames.sort()
        # Shuffle so the images picked for the test set are random within a class.
        random.shuffle(filenames)
        for filename in filenames:
            if not filename.endswith('.bmp'):
                continue
            img_path = os.path.join(dirpath, filename)
            img_data = cv2.imread(img_path)
            if img_data is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise OSError('cannot read image: {}'.format(img_path))
            img_data = cv2.resize(img_data, size, interpolation=cv2.INTER_AREA)

            root = test_path if position <= test_per_class else train_path
            out_dir = os.path.join(root, str(class_index))
            # cv2.imwrite does not create missing folders; it just returns False.
            os.makedirs(out_dir, exist_ok=True)
            out_file = os.path.join(out_dir, str(position) + '.jpg')
            if not cv2.imwrite(out_file, img_data):
                raise OSError('cannot write image: {}'.format(out_file))

            position += 1
            if position > images_per_class:
                # The current class is complete; start the next one.
                class_index += 1
                position = 1
评论列表
文章目录