test_filter.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目: catalyst 作者: enigmampc 项目源码 文件源码
def test_top_and_bottom_with_groupby_and_mask(self, dtype, seed):
        """Check grouped ``top(2)`` / ``bottom(2)`` when a mask is applied.

        The factor is ``arange(0, 64)`` reshaped to 8x8, so in the unshuffled
        layout factor values increase from left to right within each row.
        The same seeded ``permute`` is applied to the factor, the classifier,
        the mask and both expected outputs. The test therefore checks that
        selection depends on factor values, not on input ordering.

        The mask is built from ``rot90`` of ``self.eye_mask``. Judging from
        the expected outputs, it excludes the anti-diagonal entry of each row
        in the unshuffled layout.
        NOTE(review): confirm against ``eye_mask``'s definition.
        """
        permute = partial(permute_rows, seed)
        permuted_array = compose(permute, partial(array, dtype=int64_dtype))

        shape = (8, 8)

        # Shuffle the inputs (with a fixed seed) to verify that we correctly
        # pick out the top/bottom values independently of order.
        factor_data = permute(arange(0, 64, dtype=dtype).reshape(shape))
        classifier_data = permuted_array([[0, 0, 1, 1, 2, 2, 0, 0],
                                          [0, 0, 1, 1, 2, 2, 0, 0],
                                          [0, 1, 2, 3, 0, 1, 2, 3],
                                          [0, 1, 2, 3, 0, 1, 2, 3],
                                          [0, 0, 0, 0, 1, 1, 1, 1],
                                          [0, 0, 0, 0, 1, 1, 1, 1],
                                          [0, 0, 0, 0, 0, 0, 0, 0],
                                          [0, 0, 0, 0, 0, 0, 0, 0]])

        f = self.f
        c = self.c

        self.check_terms(
            terms={
                'top2': f.top(2, groupby=c),
                'bottom2': f.bottom(2, groupby=c),
            },
            initial_workspace={
                f: factor_data,
                c: classifier_data,
            },
            expected={
                # Within each classifier group: the two largest factor values,
                # i.e. the rightmost two entries of that group in the
                # unshuffled layout, skipping the masked anti-diagonal entry.
                # Groups with two or fewer unmasked entries are kept whole.
                'top2': permuted_array([[0, 1, 1, 1, 1, 1, 1, 0],
                                        [0, 1, 1, 1, 1, 1, 0, 1],
                                        [1, 1, 1, 1, 1, 0, 1, 1],
                                        [1, 1, 1, 1, 0, 1, 1, 1],
                                        [0, 1, 1, 0, 0, 0, 1, 1],
                                        [0, 1, 0, 1, 0, 0, 1, 1],
                                        [0, 0, 0, 0, 0, 0, 1, 1],
                                        [0, 0, 0, 0, 0, 0, 1, 1]], dtype=bool),
                # Within each classifier group: the two smallest factor values,
                # i.e. the LEFTMOST two entries of that group in the
                # unshuffled layout, skipping the masked anti-diagonal entry.
                # Groups with two or fewer unmasked entries are kept whole.
                'bottom2': permuted_array([[1, 1, 1, 1, 1, 1, 0, 0],
                                           [1, 1, 1, 1, 1, 1, 0, 0],
                                           [1, 1, 1, 1, 1, 0, 1, 1],
                                           [1, 1, 1, 1, 0, 1, 1, 1],
                                           [1, 1, 0, 0, 1, 1, 0, 0],
                                           [1, 1, 0, 0, 1, 1, 0, 0],
                                           [1, 0, 1, 0, 0, 0, 0, 0],
                                           [0, 1, 1, 0, 0, 0, 0, 0]],
                                          dtype=bool),
            },
            # Mask is permuted with the same seed so it lines up with the
            # shuffled inputs and expectations.
            mask=self.build_mask(permute(rot90(self.eye_mask(shape=shape)))),
        )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号