ocr.py source code

python

Project: OCR    Author: OrangeGuo
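import numpy as np


# Method of the network class in ocr.py. theta1, theta2, the two bias vectors,
# LEARNING_RATE, sigmoid() and sigmoid_prime() are defined elsewhere in that class.
def train(self, training_data_array):
    for data in training_data_array:
        # Step 1: forward propagation
        # Hidden layer: weighted input plus bias (input_layer_bias feeds the hidden layer)
        y1 = np.dot(np.asmatrix(self.theta1), np.asmatrix(data.y0).T)
        sum1 = y1 + np.asmatrix(self.input_layer_bias)
        y1 = self.sigmoid(sum1)

        # Output layer (hidden_layer_bias feeds the output layer)
        y2 = np.dot(np.array(self.theta2), y1)
        y2 = np.add(y2, self.hidden_layer_bias)
        y2 = self.sigmoid(y2)

        # Step 2: back propagation
        # One-hot target vector for the 10 digit classes
        actual_vals = [0] * 10
        actual_vals[data.label] = 1
        output_errors = np.asmatrix(actual_vals).T - np.asmatrix(y2)
        hidden_errors = np.multiply(np.dot(np.asmatrix(self.theta2).T, output_errors),
                                    self.sigmoid_prime(sum1))

        # Step 3: update weights and biases
        self.theta1 += self.LEARNING_RATE * np.dot(np.asmatrix(hidden_errors), np.asmatrix(data.y0))
        self.theta2 += self.LEARNING_RATE * np.dot(np.asmatrix(output_errors), np.asmatrix(y1).T)
        self.hidden_layer_bias += self.LEARNING_RATE * output_errors
        self.input_layer_bias += self.LEARNING_RATE * hidden_errors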
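Written as equations, each pass through the loop is one step of stochastic gradient descent on a single sample. The notation here is ours, not the author's: $x$ is the input as a column vector (`data.y0` transposed), $t$ the one-hot target (`actual_vals`), $\Theta_1, \Theta_2$ are `theta1` and `theta2`, $b_1$ is `input_layer_bias`, $b_2$ is `hidden_layer_bias`, $\eta$ is `LEARNING_RATE`, and $\sigma$ is the logistic sigmoid. It also assumes `sigmoid_prime(z)` returns $\sigma(z)(1-\sigma(z))$, which the snippet does not show.

```latex
\begin{aligned}
z_1 &= \Theta_1 x + b_1, & a_1 &= \sigma(z_1), \\
z_2 &= \Theta_2 a_1 + b_2, & a_2 &= \sigma(z_2), \\
\delta_2 &= t - a_2, & \delta_1 &= \left(\Theta_2^{\top} \delta_2\right) \odot \sigma'(z_1), \\
\Theta_2 &\leftarrow \Theta_2 + \eta\, \delta_2 a_1^{\top}, & b_2 &\leftarrow b_2 + \eta\, \delta_2, \\
\Theta_1 &\leftarrow \Theta_1 + \eta\, \delta_1 x^{\top}, & b_1 &\leftarrow b_1 + \eta\, \delta_1 .
\end{aligned}
```

Note that $\delta_2$ has no $\sigma'(z_2)$ factor. For sigmoid outputs with the binary cross-entropy loss $L = -\sum_k \left[t_k \ln a_{2,k} + (1-t_k)\ln(1-a_{2,k})\right]$, the gradient is $\partial L / \partial z_2 = a_2 - t = -\delta_2$, so the updates are consistent with gradient descent on cross-entropy rather than on squared error.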
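The method depends on attributes and helpers defined elsewhere in its class, so it cannot run by itself. The class below is a minimal sketch that wraps the same three steps into a self-contained network you can train on toy data; it is not the original project's code. Apart from the attribute names taken from the snippet, everything is a hypothetical choice: the `TinyOCRNet` and `Sample` names, the layer sizes, the uniform weight initialisation, `LEARNING_RATE = 0.1`, and the assumption that each sample exposes a flat `y0` input and an integer `label`. It also replaces `np.matrix` with plain 2-D ndarrays and the `@` operator, since the NumPy documentation no longer recommends the matrix class.

```python
from collections import namedtuple

import numpy as np

# Hypothetical sample type: y0 is a flat input vector, label an int in 0..9.
Sample = namedtuple("Sample", ["y0", "label"])


class TinyOCRNet:
    """Sketch of a one-hidden-layer sigmoid network mirroring train() above.

    Layer sizes, weight initialisation and LEARNING_RATE are assumptions.
    """

    LEARNING_RATE = 0.1

    def __init__(self, n_inputs, n_hidden, n_outputs=10, seed=0):
        rng = np.random.default_rng(seed)
        # Small random weights; biases are column vectors, as in the snippet.
        self.theta1 = rng.uniform(-0.5, 0.5, (n_hidden, n_inputs))
        self.theta2 = rng.uniform(-0.5, 0.5, (n_outputs, n_hidden))
        # Naming follows the snippet: input_layer_bias feeds the hidden layer,
        # hidden_layer_bias feeds the output layer.
        self.input_layer_bias = rng.uniform(-0.5, 0.5, (n_hidden, 1))
        self.hidden_layer_bias = rng.uniform(-0.5, 0.5, (n_outputs, 1))

    @staticmethod
    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    def sigmoid_prime(self, z):
        # Derivative of the sigmoid, evaluated at the pre-activation z.
        s = self.sigmoid(z)
        return s * (1.0 - s)

    def forward(self, x):
        # Returns the hidden pre-activation, hidden activation and output activation.
        sum1 = self.theta1 @ x + self.input_layer_bias
        y1 = self.sigmoid(sum1)
        y2 = self.sigmoid(self.theta2 @ y1 + self.hidden_layer_bias)
        return sum1, y1, y2

    def train(self, training_data_array):
        for data in training_data_array:
            x = np.asarray(data.y0, dtype=float).reshape(-1, 1)  # column vector
            sum1, y1, y2 = self.forward(x)

            # One-hot target; the output delta is (target - output), with no
            # sigmoid' factor, exactly as in the original method.
            target = np.zeros_like(y2)
            target[data.label] = 1.0
            output_errors = target - y2
            hidden_errors = (self.theta2.T @ output_errors) * self.sigmoid_prime(sum1)

            # Per-sample updates with outer products, matching the snippet's shapes.
            self.theta1 += self.LEARNING_RATE * hidden_errors @ x.T
            self.theta2 += self.LEARNING_RATE * output_errors @ y1.T
            self.hidden_layer_bias += self.LEARNING_RATE * output_errors
            self.input_layer_bias += self.LEARNING_RATE * hidden_errors

    def predict(self, y0):
        # Index of the most active output unit is the predicted label.
        x = np.asarray(y0, dtype=float).reshape(-1, 1)
        return int(np.argmax(self.forward(x)[2]))
```

For example, constructing `TinyOCRNet(n_inputs=4, n_hidden=5)` and calling `train([a, b])` a couple of thousand times on two 4-value patterns labelled 0 and 1 should let `predict()` recover each label. Keeping the biases as column vectors means each update line maps one-to-one onto the corresponding line of the original method.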