import pytest
from pyspark.sql import functions as F

import mjolnir.training.tuning


def test_split(spark_context, hive_context):
    df = (
        hive_context
        .range(1, 100 * 100)
        # Convert into 100 "queries" with 100 values each. We need a
        # sufficiently large number of queries, or the split won't have
        # enough data for the partitions to even out.
        .select(F.lit('foowiki').alias('wikiid'),
                (F.col('id') / 100).cast('int').alias('norm_query_id')))
    with_folds = mjolnir.training.tuning.split(df, (0.8, 0.2), num_partitions=4).collect()
    fold_0 = [row for row in with_folds if row.fold == 0]
    fold_1 = [row for row in with_folds if row.fold == 1]

    # Check the folds are pretty close to requested
    total_len = float(len(with_folds))
    assert 0.8 == pytest.approx(len(fold_0) / total_len, abs=0.015)
    assert 0.2 == pytest.approx(len(fold_1) / total_len, abs=0.015)

    # Check each norm query is only found on one side of the split
    queries_in_0 = set([row.norm_query_id for row in fold_0])
    queries_in_1 = set([row.norm_query_id for row in fold_1])
    assert len(queries_in_0.intersection(queries_in_1)) == 0
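

# For context, a hypothetical sketch of what a query-aware split could look
# like, assuming only the behavior this test checks: every row sharing a
# norm_query_id lands in exactly one fold, and fold sizes roughly track the
# requested weights. This is NOT mjolnir.training.tuning.split itself; the
# naive_split name is made up, and it ignores num_partitions and any fold
# balancing the real implementation performs.
def naive_split(df, weights, seed=0):
    """Tag each distinct (wikiid, norm_query_id) with a fold index.

    Assumes the weights sum to 1.0. All rows of a query share one fold
    because the fold is drawn once per distinct query, not per row.
    """
    queries = (
        df
        .select('wikiid', 'norm_query_id')
        .distinct()
        .withColumn('_rand', F.rand(seed)))
    # Map the uniform draw onto fold indices via cumulative weights,
    # e.g. weights (0.8, 0.2): fold 0 if _rand < 0.8, else fold 1.
    fold = F.lit(len(weights) - 1)
    cumulative = sum(weights)
    for i in reversed(range(len(weights) - 1)):
        cumulative -= weights[i + 1]
        fold = F.when(F.col('_rand') < cumulative, F.lit(i)).otherwise(fold)
    queries = queries.withColumn('fold', fold).drop('_rand')
    return df.join(queries, on=['wikiid', 'norm_query_id'], how='inner')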
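

# The spark_context and hive_context arguments are pytest fixtures, assumed
# to come from the suite's conftest.py. A minimal sketch of such fixtures,
# assuming a local pyspark installation (mjolnir's actual conftest may
# configure more, e.g. extra jars or logging setup):
@pytest.fixture(scope='session')
def spark_context():
    from pyspark import SparkConf, SparkContext
    conf = SparkConf().setMaster('local[2]').setAppName('mjolnir-tests')
    sc = SparkContext(conf=conf)
    yield sc
    sc.stop()


@pytest.fixture(scope='session')
def hive_context(spark_context):
    from pyspark.sql import HiveContext
    return HiveContext(spark_context)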