test_logistic_regression.py — file source

python
阅读 26 收藏 0 点赞 0 评论 0

Project: dl4nlp · Author: yohokuno · project source · file source
def assertLogisticRegression(self, sampler, seed=None):
        """Check logistic-regression gradients and training on a tiny random dataset.

        Builds 3 random examples with 5 features and random 0/1 labels, checks
        the bound cost/gradient function with ``gradient_check``, then trains
        with ``gradient_descent`` and asserts every training label is predicted
        correctly.

        Args:
            sampler: Forwarded unchanged to ``bind_cost_gradient``; its meaning
                is defined by that helper (presumably it controls which examples
                each cost/gradient evaluation uses -- confirm there).
            seed: Optional integer seed. When given, the dataset and initial
                parameters are drawn from a private ``np.random.RandomState(seed)``
                so a failing run can be reproduced exactly. When ``None`` (the
                default, matching the previous behaviour) the global
                ``np.random`` state is used.
        """
        # Reproducible data when a seed is supplied; otherwise keep drawing from
        # the global generator so callers that seed it themselves still work.
        rng = np.random if seed is None else np.random.RandomState(seed)

        data_size = 3
        input_size = 5
        inputs = rng.uniform(-10.0, 10.0, size=(data_size, input_size))
        outputs = rng.randint(0, 2, size=data_size)
        # Tiny initial weights keep the initial logits (inputs . weights) near zero.
        initial_parameters = rng.normal(scale=1e-5, size=input_size)

        # Bind the data to the cost/gradient function and check its gradient;
        # gradient_check is expected to return an empty list on success.
        cost_gradient = bind_cost_gradient(logistic_regression_cost_gradient,
                                           inputs, outputs, sampler=sampler)
        result = gradient_check(cost_gradient, initial_parameters)
        self.assertEqual([], result)

        # Train (100 is gradient_descent's third argument, presumably the number
        # of iterations); the returned cost history is not needed here.
        final_parameters, _ = gradient_descent(cost_gradient, initial_parameters, 100)
        # No intercept term: predict class 1 when sigmoid(x . w) > 0.5.
        predictions = expit(np.dot(inputs, final_parameters)) > 0.5

        # 3 points drawn from a continuous distribution in 5 dimensions are almost
        # surely linearly independent, so some weight vector through the origin
        # realises any labelling; training should therefore fit every label.
        # Comparing whole lists reports all mismatching positions at once.
        self.assertEqual(outputs.tolist(), predictions.astype(int).tolist())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号