_test.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:ibex 作者: atavory 项目源码 文件源码
def _generate_cross_val_predict_test(X, y, est, pd_est, must_match):
    def test(self):
        self.assertEqual(
            hasattr(est, 'predict'),
            hasattr(pd_est, 'predict'))
        if not hasattr(est, 'predict'):
            return
        pd_y_hat = pd_cross_val_predict(pd_est, X, y)
        self.assertTrue(isinstance(pd_y_hat, pd.Series))
        self.assertTrue(pd_y_hat.index.equals(X.index))
        if must_match:
            y_hat = cross_val_predict(est, X.as_matrix(), y.values)
            np.testing.assert_allclose(pd_y_hat, y_hat)
    return test
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号