digit_classifier.py source code

python

Project: pyku  Author: dubvulture
import os

import cv2
import numpy as np


def create_model(self, train_folder):
        """
        Return the training set, its labels and the trained model
        :param train_folder: folder where to retrieve data
        :return: (train_set, train_labels, trained_model)
        """
        digits = []
        labels = []
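        # One sub-folder per digit, 1 through 9
        for n in range(1, 10):
            # train_folder is concatenated as-is, so it needs a trailing
            # path separator for the digit folders to resolve inside it
            folder = train_folder + str(n)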
            samples = [pic for pic in os.listdir(folder)
                       if os.path.isfile(os.path.join(folder, pic))]

            for sample in samples:
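                image = cv2.imread(os.path.join(folder, sample))
                if image is None:
                    # cv2.imread returns None for files it cannot decode
                    continue
                # Input is expected to be black on white; invert it
                image = 255 - cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)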
                _, image = cv2.threshold(image, 0, 255,
                                         cv2.THRESH_BINARY + cv2.THRESH_OTSU)
                feat = self.feature(image)
                digits.append(feat)
                labels.append(n)

        digits = np.array(digits, np.float32)
        labels = np.array(labels, np.float32)
        # OpenCV 2.x and later versions expose different k-NN APIs
        if cv2.__version__[0] == '2':
            model = cv2.KNearest()
            model.train(digits, labels)
        else:
            model = cv2.ml.KNearest_create()
            model.train(digits, cv2.ml.ROW_SAMPLE, labels)
        return digits, labels, model
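
The method above returns a trained k-nearest-neighbour model. As a rough illustration of what such a model does at prediction time, here is a minimal pure-Python sketch. It is not OpenCV's implementation: `knn_predict`, its parameters, and the toy feature vectors are hypothetical names and data introduced only for this example.

```python
from collections import Counter
import math


def knn_predict(train_set, train_labels, sample, k=3):
    """Label `sample` by majority vote among its k nearest training vectors."""
    # Euclidean distance from the sample to every training vector
    distances = [
        (math.dist(vec, sample), label)
        for vec, label in zip(train_set, train_labels)
    ]
    # Keep the k closest neighbours
    nearest = sorted(distances, key=lambda pair: pair[0])[:k]
    # Majority vote; on a tie, the label of the closer neighbour wins
    votes = Counter(label for _, label in nearest)
    return votes.most_common(1)[0][0]


# Toy 2-D feature vectors standing in for the output of self.feature(image)
train_set = [(0.0, 0.0), (0.1, 0.2), (5.0, 5.0), (5.2, 4.9)]
train_labels = [1, 1, 7, 7]
prediction = knn_predict(train_set, train_labels, (0.2, 0.1), k=3)
```

With the real model, the equivalent query would go through OpenCV itself: `model.find_nearest(samples, k)` in OpenCV 2.x, or `model.findNearest(samples, k)` in later versions, with `samples` as a `np.float32` array shaped like the training set.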