def test_train_pre_prepped(df_train):
    """Train on data that was already run through ``prep_training``.

    ``prep_training`` is called once up front. Its grouped DataFrame and
    group data are then passed to ``train`` (groups via the
    ``'groupData'`` param). During that call the module-level
    ``prep_training`` is temporarily replaced by ``_always_raise``, so the
    test should fail if ``train`` or ``model.eval`` try to prepare the data
    again. NOTE(review): this assumes ``_always_raise`` raises when called,
    as its name suggests; confirm against its definition.
    """
    # Local import: the file's top-level import block is not visible here.
    from unittest import mock

    num_workers = 1
    df_grouped, j_groups = mjolnir.training.xgboost.prep_training(
        df_train, num_workers)
    params = {
        'num_rounds': 1,
        # Hand over the precomputed groups alongside the grouped data.
        'groupData': j_groups,
    }
    # patch.object swaps the module attribute and restores the original on
    # exit, even if train()/eval() or the assertion fails.
    with mock.patch.object(
            mjolnir.training.xgboost, 'prep_training', _always_raise):
        model = mjolnir.training.xgboost.train(df_grouped, params)
        score = model.eval(df_grouped, j_groups)
        assert score == pytest.approx(0.74, abs=0.01)
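# Web-page scraping residue (not code): "评论列表" = "comment list"
# Web-page scraping residue (not code): "文章目录" = "table of contents"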