def __init__(self,
             saved_model=None,
             train_folder=None,
             feature=_feature.__func__):
    """
    Set up a k-nearest-neighbours model and its training data.

    :param saved_model: optional path to a .npz archive holding the arrays
        ``train_set`` and ``train_labels``; when omitted,
        ``TRAIN_DATA + 'raw_pixel_data.npz'`` is loaded instead
    :param train_folder: optional folder of custom training images; when
        given, everything is built by ``self.create_model`` and
        ``saved_model`` is not used
    :param feature: feature function - must be compatible with saved_model
    """
    self.feature = feature
    if train_folder is not None:
        # Build samples, labels and the trained model from raw images.
        self.train_set, self.train_labels, self.model = \
            self.create_model(train_folder)
        return

    # OpenCV 2.x exposes KNearest at the top level; 3.x+ moved it to cv2.ml.
    legacy_api = cv2.__version__[0] == '2'
    self.model = cv2.KNearest() if legacy_api else cv2.ml.KNearest_create()

    archive_path = (TRAIN_DATA + 'raw_pixel_data.npz'
                    if saved_model is None else saved_model)
    with np.load(archive_path) as archive:
        self.train_set = archive['train_set']
        self.train_labels = archive['train_labels']

    if legacy_api:
        self.model.train(self.train_set, self.train_labels)
    else:
        # The cv2.ml API needs the sample layout spelled out explicitly.
        self.model.train(self.train_set, cv2.ml.ROW_SAMPLE,
                         self.train_labels)
# Python KNearest() example source code
def create_model(self, train_folder, classes=None):
    """
    Return the training set, its labels and the trained model

    Samples for class ``n`` are read from the folder ``train_folder + str(n)``
    (plain string concatenation, so ``train_folder`` normally ends with a
    path separator). Each image is converted to grayscale, inverted, and
    binarised with Otsu's threshold before ``self.feature`` is applied.
    Directory entries that are not regular files, or that OpenCV cannot
    decode as an image, are skipped.

    :param train_folder: folder prefix where to retrieve data
    :param classes: iterable of labels to train on, one sub-folder each;
        defaults to 1 through 9
    :return: (train_set, train_labels, trained_model), both arrays float32
    :raises ValueError: if no readable training image was found
    """
    if classes is None:
        classes = range(1, 10)
    digits = []
    labels = []
    for n in classes:
        folder = train_folder + str(n)
        # Sorted so the returned sample order does not depend on the
        # file system's (arbitrary) listing order.
        for sample in sorted(os.listdir(folder)):
            path = os.path.join(folder, sample)
            if not os.path.isfile(path):
                continue
            image = cv2.imread(path)
            if image is None:
                # imread returns None for files it cannot decode (e.g. stray
                # metadata files); skip them like non-file entries.
                continue
            # Expecting black on white: invert so the foreground is white.
            image = 255 - cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            _, image = cv2.threshold(image, 0, 255,
                                     cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            digits.append(self.feature(image))
            labels.append(n)
    if not digits:
        # Training on an empty set would otherwise fail deep inside OpenCV.
        raise ValueError('no readable training images found under %r'
                         % (train_folder,))
    digits = np.array(digits, np.float32)
    labels = np.array(labels, np.float32)
    # OpenCV 2.x exposes KNearest at the top level; 3.x+ moved it to cv2.ml.
    if cv2.__version__[0] == '2':
        model = cv2.KNearest()
        model.train(digits, labels)
    else:
        model = cv2.ml.KNearest_create()
        model.train(digits, cv2.ml.ROW_SAMPLE, labels)
    return digits, labels, model
def __init__(self):
    """
    Train a k-nearest-neighbours model on the collected captcha samples.

    Every sub-directory of ``captcha/collect`` holds sample images of one
    character and is named after it; the label of each sample is ``ord`` of
    that directory name, so names must be a single character. Images are
    read as grayscale and flattened to rows of 400 values (presumably 20x20
    pixel crops - TODO confirm against the collector). Entries that are not
    directories (top level) or not regular files (inside a character
    directory) are ignored. The trained model is stored in ``self.knn``.
    """
    collect_dir = 'captcha/collect'
    labels = []
    sample_paths = []
    # Sorted so the sample order does not depend on the file system.
    for char_name in sorted(os.listdir(collect_dir)):
        char_dir = os.path.join(collect_dir, char_name)
        if not os.path.isdir(char_dir):
            # Stray files next to the per-character folders used to crash
            # the inner os.listdir call.
            continue
        for file_name in sorted(os.listdir(char_dir)):
            sample_path = os.path.join(char_dir, file_name)
            if not os.path.isfile(sample_path):
                continue
            labels.append(ord(char_name))
            sample_paths.append(sample_path)
    # Flag 0 == cv2.IMREAD_GRAYSCALE.
    images = [cv2.imread(path, 0) for path in sample_paths]
    samples = np.array(images).reshape(-1, 400).astype(np.float32)
    responses = np.array(labels).reshape(-1)
    # OpenCV 2.x exposes KNearest at the top level; 3.x+ moved it to cv2.ml.
    if cv2.__version__[0] == '2':
        self.knn = cv2.KNearest()
        self.knn.train(samples, responses)
    else:
        self.knn = cv2.ml.KNearest_create()
        self.knn.train(samples, cv2.ml.ROW_SAMPLE, responses)
def get_matrix(train_dir='static/train/', test_dir='static/ClippedImg/', k=3):
    """
    Classify clipped images and write the predictions into a board matrix.

    A k-nearest-neighbours model is trained on the images found by
    ``train.find_picture(train_dir)`` and used to classify every image found
    under ``test_dir``. Each test file name must contain its cell position as
    ``<x>_<y>`` (one or two digits each). Cells whose prediction is 1 or 2 are
    set to that value in the matrix from ``init_matrix()``; any other
    prediction leaves the cell unchanged.

    :param train_dir: folder with the labelled training images
    :param test_dir: folder with the clipped images to classify
    :param k: number of neighbours consulted per prediction
    :return: the matrix produced by ``init_matrix()``, filled in
    :raises ValueError: if an image predicted as 1 or 2 has no ``<x>_<y>``
        position in its file name
    """
    chessboard_matrix = init_matrix()
    train_filename_list, train_label_list = train.find_picture(train_dir)
    train_file_list = train.preprocess_img(train_filename_list)
    test_filename_list, _ = train.find_picture(test_dir)
    test_file_list = train.preprocess_img(test_filename_list)

    # OpenCV 2.x exposes KNearest at the top level; 3.x+ moved it to cv2.ml
    # and renamed find_nearest to findNearest.
    # NOTE(review): the cv2.ml API needs numpy arrays (float32 samples,
    # float32/int32 responses) - confirm what train.find_picture and
    # train.preprocess_img return.
    if cv2.__version__[0] == '2':
        knn = cv2.KNearest()
        knn.train(train_file_list, train_label_list)
        _, result, _, _ = knn.find_nearest(test_file_list, k=k)
    else:
        knn = cv2.ml.KNearest_create()
        knn.train(train_file_list, cv2.ml.ROW_SAMPLE, train_label_list)
        _, result, _, _ = knn.findNearest(test_file_list, k=k)

    position_pattern = re.compile(r'\d{1,2}_\d{1,2}')
    for prediction, filename in zip(result, test_filename_list):
        label = int(prediction[0])
        if label not in (1, 2):
            continue
        match = position_pattern.search(str(filename))
        if match is None:
            raise ValueError('no <x>_<y> position in file name %r'
                             % (filename,))
        xposition, yposition = (int(part)
                                for part in match.group().split('_'))
        chessboard_matrix[xposition][yposition] = label
    return chessboard_matrix