def test_top_and_bottom_with_groupby_and_mask(self, dtype, seed):
    """Check ``top(2, groupby=c)`` and ``bottom(2, groupby=c)`` under a mask.

    The factor holds ``0..63`` laid out row by row.  Before permutation its
    values therefore increase from left to right within every row.
    ``top(2)`` should select the two rightmost unmasked entries of each
    classifier group in each row, and ``bottom(2)`` the two leftmost.

    The expected arrays exclude the anti-diagonal (row ``i``, column
    ``7 - i``).  This presumably comes from ``eye_mask`` being False only
    on the main diagonal, with ``rot90`` moving those cells onto the
    anti-diagonal; confirm against ``eye_mask``.

    Inputs, mask and expectations are all passed through ``permute_rows``
    with the same ``seed``.  This assumes ``permute_rows`` is deterministic
    per seed, so every array is shuffled identically.
    """
    permute = partial(permute_rows, seed)
    permuted_array = compose(permute, partial(array, dtype=int64_dtype))

    shape = (8, 8)

    # Shuffle the inputs to verify that we correctly pick out the top and
    # bottom values independently of their order.
    factor_data = permute(arange(0, 64, dtype=dtype).reshape(shape))
    classifier_data = permuted_array([[0, 0, 1, 1, 2, 2, 0, 0],
                                      [0, 0, 1, 1, 2, 2, 0, 0],
                                      [0, 1, 2, 3, 0, 1, 2, 3],
                                      [0, 1, 2, 3, 0, 1, 2, 3],
                                      [0, 0, 0, 0, 1, 1, 1, 1],
                                      [0, 0, 0, 0, 1, 1, 1, 1],
                                      [0, 0, 0, 0, 0, 0, 0, 0],
                                      [0, 0, 0, 0, 0, 0, 0, 0]])

    f = self.f
    c = self.c

    self.check_terms(
        terms={
            'top2': f.top(2, groupby=c),
            'bottom2': f.bottom(2, groupby=c),
        },
        initial_workspace={
            f: factor_data,
            c: classifier_data,
        },
        # The call-time ``dtype=bool`` overrides the int64 dtype baked into
        # ``permuted_array`` by ``partial``.
        expected={
            # The rightmost two unmasked entries of each classifier group in
            # each row, ignoring the masked-out anti-diagonal.  Groups with
            # two or fewer unmasked members (e.g. rows 2 and 3) are selected
            # in full.
            'top2': permuted_array([[0, 1, 1, 1, 1, 1, 1, 0],
                                    [0, 1, 1, 1, 1, 1, 0, 1],
                                    [1, 1, 1, 1, 1, 0, 1, 1],
                                    [1, 1, 1, 1, 0, 1, 1, 1],
                                    [0, 1, 1, 0, 0, 0, 1, 1],
                                    [0, 1, 0, 1, 0, 0, 1, 1],
                                    [0, 0, 0, 0, 0, 0, 1, 1],
                                    [0, 0, 0, 0, 0, 0, 1, 1]],
                                   dtype=bool),
            # The leftmost two unmasked entries of each classifier group in
            # each row, ignoring the masked-out anti-diagonal.
            'bottom2': permuted_array([[1, 1, 1, 1, 1, 1, 0, 0],
                                       [1, 1, 1, 1, 1, 1, 0, 0],
                                       [1, 1, 1, 1, 1, 0, 1, 1],
                                       [1, 1, 1, 1, 0, 1, 1, 1],
                                       [1, 1, 0, 0, 1, 1, 0, 0],
                                       [1, 1, 0, 0, 1, 1, 0, 0],
                                       [1, 0, 1, 0, 0, 0, 0, 0],
                                       [0, 1, 1, 0, 0, 0, 0, 0]],
                                      dtype=bool),
        },
        mask=self.build_mask(permute(rot90(self.eye_mask(shape=shape)))),
    )
评论列表
文章目录