test_logistic_regression.py source code

python

Project: dl4nlp  Author: yohokuno
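import numpy as np
from scipy.special import expit

# logistic_regression_cost_gradient, gradient_check and gradient_descent
# are defined in the dl4nlp project; their import lines are not shown here.


def test_logistic_regression(self):
    # One random 10-dimensional example with a random binary label
    input = np.random.uniform(-10.0, 10.0, size=10)
    output = np.random.randint(0, 2)

    def logistic_regression_wrapper(parameters):
        # Bind the example so the cost depends on the parameters only
        return logistic_regression_cost_gradient(parameters, input, output)

    # The test expects gradient_check to return an empty list
    initial_parameters = np.random.normal(scale=1e-5, size=10)
    result = gradient_check(logistic_regression_wrapper, initial_parameters)
    self.assertEqual([], result)

    # Train logistic regression and see if it predicts the correct label
    final_parameters, cost_history = gradient_descent(logistic_regression_wrapper, initial_parameters, 100)
    prediction = expit(np.dot(input, final_parameters)) > 0.5
    self.assertEqual(output, prediction)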
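The test above relies on three helpers whose bodies are not shown on this page. The following is a minimal, self-contained sketch of what such helpers could look like, so the test logic can be tried in isolation. It is not the dl4nlp implementation: the names `logistic_cost_gradient`, `numerical_gradient_check` and `plain_gradient_descent`, the default `epsilon`, `tolerance` and `learning_rate` values, and the choice to return mismatching indices are all assumptions made for illustration.

```python
import numpy as np
from scipy.special import expit


def logistic_cost_gradient(parameters, x, y):
    """Cross-entropy cost and gradient for a single example (hypothetical helper)."""
    z = np.dot(x, parameters)
    # log(1 + exp(z)) - y * z equals -[y log p + (1 - y) log(1 - p)] with p = expit(z)
    cost = np.logaddexp(0.0, z) - y * z
    gradient = (expit(z) - y) * x
    return cost, gradient


def numerical_gradient_check(cost_gradient, parameters, epsilon=1e-6, tolerance=1e-5):
    """Return the indices where analytic and central-difference gradients disagree."""
    _, analytic = cost_gradient(parameters)
    mismatches = []
    for i in range(parameters.size):
        step = np.zeros_like(parameters)
        step[i] = epsilon
        cost_plus, _ = cost_gradient(parameters + step)
        cost_minus, _ = cost_gradient(parameters - step)
        numeric = (cost_plus - cost_minus) / (2.0 * epsilon)
        scale = max(1.0, abs(numeric), abs(analytic[i]))
        if abs(numeric - analytic[i]) > tolerance * scale:
            mismatches.append(i)
    return mismatches


def plain_gradient_descent(cost_gradient, parameters, iterations, learning_rate=0.1):
    """Fixed-step gradient descent; returns final parameters and the cost history."""
    history = []
    for _ in range(iterations):
        cost, gradient = cost_gradient(parameters)
        history.append(cost)
        parameters = parameters - learning_rate * gradient
    return parameters, history


# Same scenario as the test: one random example, near-zero initial parameters.
rng = np.random.default_rng(0)  # seeded so the run is repeatable
x = rng.uniform(-10.0, 10.0, size=10)
y = int(rng.integers(0, 2))


def objective(parameters):
    return logistic_cost_gradient(parameters, x, y)


w0 = rng.normal(scale=1e-5, size=10)
mismatches = numerical_gradient_check(objective, w0)
w, history = plain_gradient_descent(objective, w0, 100)
prediction = int(expit(np.dot(x, w)) > 0.5)
```

With a single training example, every update moves `np.dot(x, w)` toward the sign that matches the label, which is why the original test can assert that the trained model predicts `output` after 100 iterations.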