def _generate_cross_val_predict_test(X, y, est, pd_est, must_match):
    """Build a test method comparing ``pd_cross_val_predict`` with ``cross_val_predict``.

    The returned function takes ``self`` and calls ``assertEqual``,
    ``assertTrue`` and ``assertIsInstance`` on it. It is therefore presumably
    attached to a ``unittest.TestCase`` subclass by the caller (verify against
    the caller).

    Args:
        X: Pandas feature object (presumably a DataFrame). The prediction must
            carry its ``index``, and its ``.values`` are passed to the plain
            ``cross_val_predict``.
        y: Pandas target; its ``.values`` are passed to the plain
            ``cross_val_predict``.
        est: Estimator handed to ``cross_val_predict``.
        pd_est: Estimator handed to ``pd_cross_val_predict``.
        must_match: If true, also require the two prediction results to be
            numerically close via ``np.testing.assert_allclose``.

    Returns:
        A ``test(self)`` function performing the checks described above.
    """
    def test(self):
        # Both estimators must agree on whether they expose ``predict``.
        self.assertEqual(
            hasattr(est, 'predict'),
            hasattr(pd_est, 'predict'))
        # The assertion above passed, so neither estimator can predict here:
        # there is nothing further to check.
        if not hasattr(est, 'predict'):
            return
        pd_y_hat = pd_cross_val_predict(pd_est, X, y)
        # The pandas-aware variant must return a Series aligned to X's index.
        self.assertIsInstance(pd_y_hat, pd.Series)
        self.assertTrue(pd_y_hat.index.equals(X.index))
        if must_match:
            # ``DataFrame.as_matrix`` was removed in pandas 1.0; ``.values``
            # yields the same ndarray and mirrors how ``y`` is converted.
            y_hat = cross_val_predict(est, X.values, y.values)
            np.testing.assert_allclose(pd_y_hat, y_hat)
    return test
评论列表
文章目录