def test_server_median(N, V, C, M):
    """Check the shape, dtype and per-feature totals of ``TreeCatServer.median``.

    A fake model (built from ``N``, ``V``, ``C``, ``M``) is served with a copy
    of ``TINY_CONFIG`` whose ``model_num_clusters`` is set to ``M``.
    ``server.median`` is then evaluated on a freshly generated dataset, using
    random per-feature ``counts`` in ``[0, 10)``.

    The result must:

    * have the same shape as the dataset's data array;
    * have dtype ``np.int8``;
    * for every feature ``v``, have each row's entries over the columns
      ``table.ragged_index[v:v + 2]`` sum to ``counts[v]``.
    """
    fake_model = generate_fake_model(N, V, C, M)
    tiny_config = TINY_CONFIG.copy()
    tiny_config['model_num_clusters'] = M
    fake_model['config'] = tiny_config
    server = TreeCatServer(fake_model)

    # Evaluate on random data. Keep the randint draw ahead of
    # generate_dataset(): if that helper also uses numpy's global RNG
    # (not verified here), swapping the two would change both draws.
    counts = np.random.randint(10, size=[V], dtype=np.int8)
    dataset_table = generate_dataset(N, V, C)['table']
    data = dataset_table.data
    median = server.median(counts, data)

    assert median.shape == data.shape
    assert median.dtype == np.int8

    ragged_index = dataset_table.ragged_index
    for feature in range(V):
        # Columns beg:end hold feature `feature` in the ragged layout.
        beg, end = ragged_index[feature:feature + 2]
        row_totals = median[:, beg:end].sum(axis=1)
        assert (row_totals == counts[feature]).all()
评论列表
文章目录