def test_server_logprob_normalized(N, V, C, M):
    """Check that ``TreeCatServer.logprob`` is normalized over categorical rows.

    Builds a fake model via ``generate_fake_model(N, V, C, M)``, sets
    ``model_num_clusters`` to ``M`` in a copy of ``TINY_CONFIG``, then
    enumerates every row in which each of the ``V`` features takes exactly one
    of its categories. It asserts that the log-sum-exp of the server's
    log-probabilities for those rows is 0, i.e. that their total probability
    is 1.

    Args:
        N: Passed through to ``generate_fake_model`` (presumably the number of
            rows; not otherwise used here).
        V: Number of features enumerated; feature ``v`` is sized from
            consecutive ``ragged_index`` entries.
        C: Passed through to ``generate_fake_model``. The per-feature category
            counts actually enumerated are read back from the model's
            ``ragged_index``, not from ``C``.
        M: Value written to ``config['model_num_clusters']``.
    """
    model = generate_fake_model(N, V, C, M)
    config = TINY_CONFIG.copy()
    config['model_num_clusters'] = M
    model['config'] = config
    server = TreeCatServer(model)

    # The total probability of all categorical rows should be 1.
    # Presumably feature v spans columns ragged_index[v]:ragged_index[v + 1],
    # so the width of that slice is its number of categories.
    # Use a fresh local name instead of rebinding the parameter ``C``.
    ragged_index = model['suffstats']['ragged_index']
    factors = []
    for v in range(V):
        num_categories = ragged_index[v + 1] - ragged_index[v]
        # One candidate encoding per category (one_hot assumed to return a
        # length-num_categories indicator vector -- defined elsewhere).
        factors.append(
            [one_hot(c, num_categories) for c in range(num_categories)])
    # Cartesian product across features: one row per joint assignment.
    data = np.array(
        [np.concatenate(columns) for columns in itertools.product(*factors)],
        dtype=np.int8)
    logprobs = server.logprob(data)
    # Numerically stable log(sum(exp(logprobs))); 0.0 means total mass 1.
    logtotal = np.logaddexp.reduce(logprobs)
    assert logtotal == pytest.approx(0.0, abs=1e-5)
# (Web-page residue, not code: "Comment list")
# (Web-page residue, not code: "Table of contents")