serving_test.py 文件源码

python
阅读 41 收藏 0 点赞 0 评论 0

项目:treecat 作者: posterior 项目源码 文件源码
def validate_sample_shape(table, server):
    # Sample many different counts patterns.
    V = table.num_cols
    N = table.num_rows
    factors = [[0, 1, 2]] * V
    for counts in itertools.product(*factors):
        counts = np.array(counts, dtype=np.int8)
        for n in range(N):
            row = table.data[n, :]
            samples = server.sample(N, counts, row)
            assert samples.shape == (N, row.shape[0])
            assert samples.dtype == row.dtype
            for v in range(V):
                beg, end = table.ragged_index[v:v + 2]
                assert np.all(samples[:, beg:end].sum(axis=1) == counts[v])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号