def test_server_marginals(N, V, C, M):
    """Check that server marginals sum to one within every feature's columns.

    A fake model is generated from ``(N, V, C, M)``, given a copy of
    ``TINY_CONFIG`` whose ``'model_num_clusters'`` entry is set to ``M``, and
    wrapped in a ``TreeCatServer``. Marginals are then computed for a freshly
    generated dataset table, and for each of the ``V`` features the columns
    delimited by consecutive ``table.ragged_index`` entries are summed along
    axis 1 and compared against 1.0.

    NOTE(review): presumably N/V/C are rows/features/categories and the
    column slices hold per-feature category probabilities -- confirm against
    ``generate_dataset`` and ``TreeCatServer.marginals``.
    """
    # Build the server from a fake model carrying a tiny, M-cluster config.
    fake_model = generate_fake_model(N, V, C, M)
    tiny_config = TINY_CONFIG.copy()
    tiny_config['model_num_clusters'] = M
    fake_model['config'] = tiny_config
    server = TreeCatServer(fake_model)

    # Evaluate on random data.
    dataset = generate_dataset(N, V, C)
    table = dataset['table']
    marginals = server.marginals(table.data)

    for feature in range(V):
        # Two consecutive offsets bound this feature's block of columns.
        start, stop = table.ragged_index[feature:feature + 2]
        feature_block = marginals[:, start:stop]
        totals = feature_block.sum(axis=1)
        assert np.allclose(totals, 1.0)
评论列表
文章目录