def split_train_test_imgs(class_names, test_ratio):
    """Split each class's image list into train and test portions.

    For every name in ``class_names`` the file ``<name>.txt`` is read and
    each line is treated as one image entry. For a file with ``n`` lines,
    the first ``int(test_ratio * n)`` lines go to the test set and the rest
    go to the training set. Results from all classes are concatenated in the
    order the classes are given, and the resulting counts are printed.

    Args:
        class_names: Iterable of class names; ``'.txt'`` is appended to each
            to form the path of that class's image-list file.
        test_ratio: Fraction of each class's lines to place in the test set.
            Values <= 0 put every line in train; values >= 1 put every line
            in test.

    Returns:
        A ``(train_imgs, test_imgs)`` tuple of lists. Entries are the
        whitespace-stripped lines as ``bytes``, because the files are opened
        in binary mode.
    """
    train_imgs = []
    test_imgs = []
    for class_name in class_names:
        file_name = class_name + '.txt'
        # Read the file once; the line count comes from the same pass, so no
        # separate counting pass over the file is needed.
        with open(file_name, 'rb') as f:
            lines = [line.strip() for line in f]
        # Slicing at floor(ratio * n) takes exactly that many leading lines.
        # A 1-based `line_no < ratio * n` test would drop one test image
        # whenever the product is a whole number. Clamp at 0 so a negative
        # ratio cannot become a negative slice index.
        num_test_imgs = max(0, int(test_ratio * len(lines)))
        test_imgs.extend(lines[:num_test_imgs])
        train_imgs.extend(lines[num_test_imgs:])
    print(str(len(train_imgs)) + ' train images')
    print(str(len(test_imgs)) + ' test images')
    return train_imgs, test_imgs
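# 评论列表 ("comment list") - navigation text left over from the source web page, not code
# 文章目录 ("table of contents") - navigation text left over from the source web page, not code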