def test_latent_correlation(N, V, C, M):
    """Check ``TreeCatServer.latent_correlation()`` on a generated fake model.

    The resulting matrix must have every entry in [0, 1] and be symmetric
    (up to ``np.allclose`` tolerance). For each index ``v`` in ``range(V)``,
    the diagonal entry must be the argmax of both row ``v`` and column ``v``.
    """
    # Seed from the test parameters; presumably make_seed maps them to a
    # fixed seed so that a failing case can be reproduced (confirm).
    set_random_seed(make_seed(N, V, C, M))
    fake_model = generate_fake_model(N, V, C, M)

    # Copy the shared tiny config so the override below does not mutate it.
    server_config = TINY_CONFIG.copy()
    server_config['model_num_clusters'] = M
    fake_model['config'] = server_config

    corr = TreeCatServer(fake_model).latent_correlation()
    print(corr)

    # Bounds and symmetry.
    assert np.all(0 <= corr)
    assert np.all(corr <= 1)
    assert np.allclose(corr, corr.T)

    # The diagonal dominates its row and its column.
    # np.argmax returns the first maximal index on ties.
    for idx in range(V):
        assert corr[idx, :].argmax() == idx
        assert corr[:, idx].argmax() == idx
评论列表
文章目录