def validate_sample_shape(table, server):
    """Check that ``server.sample`` returns well-formed samples.

    For every counts pattern in ``{0, 1, 2} ** table.num_cols`` and every
    row of ``table.data``, draw ``table.num_rows`` samples conditioned on
    that row and assert that:

    * the result has shape ``(table.num_rows, row.shape[0])``,
    * the result has the same dtype as the row, and
    * for each column ``v``, the entries of every sample inside the slice
      ``table.ragged_index[v]:table.ragged_index[v + 1]`` sum to
      ``counts[v]``.

    There are ``3 ** table.num_cols`` counts patterns, so this helper is
    only practical for tables with few columns.

    Args:
        table: Object exposing ``num_cols``, ``num_rows``, a 2-D ``data``
            array indexed as ``data[n, :]``, and a ``ragged_index`` holding
            at least ``num_cols + 1`` column offsets.
        server: Object whose ``sample(N, counts, row)`` method is under test.

    Raises:
        AssertionError: If any returned sample violates one of the checks.
    """
    # Local imports keep this helper self-contained; they are no-ops if the
    # module already imports these names.
    import itertools

    import numpy as np

    V = table.num_cols
    N = table.num_rows
    # Column slice bounds depend only on the table, so build them once
    # instead of re-slicing ragged_index for every (counts, row) pair.
    col_bounds = [tuple(table.ragged_index[v:v + 2]) for v in range(V)]

    # Sample many different counts patterns: 0, 1 or 2 per column.
    factors = [[0, 1, 2]] * V
    for counts in itertools.product(*factors):
        counts = np.array(counts, dtype=np.int8)
        for n in range(N):
            row = table.data[n, :]
            samples = server.sample(N, counts, row)
            expected_shape = (N, row.shape[0])
            assert samples.shape == expected_shape, (
                samples.shape, expected_shape)
            assert samples.dtype == row.dtype, (samples.dtype, row.dtype)
            for v, (beg, end) in enumerate(col_bounds):
                assert np.all(samples[:, beg:end].sum(axis=1) == counts[v]), (
                    v, counts)
# Scraped web-page residue (not code): "评论列表" (Comment list),
# "文章目录" (Table of contents).