def test_observed_perplexity(N, V, C, M):
    """Check TreeCatServer.observed_perplexity on a small fake model.

    Builds a fake model from the size parameters and wraps it in a
    TreeCatServer. It then queries ``observed_perplexity`` for observation
    counts 1, 2 and 3, checking the shape and bounds of each result.

    Args:
        N, V, C, M: Size parameters forwarded to ``make_seed`` and
            ``generate_fake_model``. ``M`` is also stored as
            ``config['model_num_clusters']``. The assertions below expect
            a result of shape ``(V,)`` with values in ``[1, count * C]``.
    """
    set_random_seed(make_seed(N, V, C, M))
    model = generate_fake_model(N, V, C, M)
    config = TINY_CONFIG.copy()
    config['model_num_clusters'] = M
    model['config'] = config
    server = TreeCatServer(model)

    for count in [1, 2, 3]:
        if count > 1 and C > 2:
            # Presumably observed_perplexity raises NotImplementedError for
            # count > 1 when C > 2 -- TODO confirm against TreeCatServer.
            continue
        # Query with this iteration's count. The previous hard-coded value
        # of 1 made every iteration identical and never tested the
        # count-dependent upper bound below.
        counts = count
        perplexity = server.observed_perplexity(counts)
        print(perplexity)
        assert perplexity.shape == (V, )
        assert np.all(1 <= perplexity)
        assert np.all(perplexity <= count * C)
# Web-page residue (not code): "Comment list" / "Table of contents".