def test_slda():
    """Fit SLDA on a generated corpus and check the learned topics.

    A corpus is obtained from ``language(10000)``; a response is then
    synthesised for each document as ``dot(eta, thetas[i])`` plus Gaussian
    noise. After fitting, ``slda.phi`` must pass
    ``assert_probablity_distribution`` and ``check_KL_divergence`` against
    the corpus's ``topics`` with a threshold of 0.001.
    """
    corpus = language(10000)
    n_topics = corpus['K']
    n_docs = corpus['D']
    seed = corpus['seed']

    n_iter = 2000
    kl_threshold = 0.001
    nu2 = n_topics
    sigma2 = 1

    # Seed before any draw so the coefficients and the noise are reproducible.
    np.random.seed(seed)
    # NOTE(review): np.random.normal's `scale` is a standard deviation, while
    # the names nu2 / sigma2 read like variances -- confirm this is intended.
    eta = np.random.normal(scale=nu2, size=n_topics)
    noiseless = [np.dot(eta, corpus['thetas'][doc]) for doc in range(n_docs)]
    noise = np.random.normal(scale=sigma2, size=n_docs)
    # list + ndarray: numpy performs the addition element-wise.
    responses = noiseless + noise

    beta_prior = np.repeat(0.01, corpus['V'])
    mu_prior = 0
    model = SLDA(
        n_topics, corpus['alpha'], beta_prior, mu_prior, nu2, sigma2, n_iter,
        seed=seed, n_report_iter=corpus['n_report_iters'],
    )
    model.fit(corpus['doc_term_matrix'], responses)

    assert_probablity_distribution(model.phi)
    check_KL_divergence(corpus['topics'], model.phi, kl_threshold)
# Comment list
# Table of contents