def test_blslda():
    """Fit BLSLDA on a synthetic corpus and check the recovered topics.

    Binary document labels are generated from the fixture's per-document
    topic proportions using a random linear predictor ``eta``. Document
    ``d`` is labelled 1 when ``dot(eta, thetas[d]) >= 0`` and 0 otherwise.

    After fitting, the test asserts two things about ``blslda.phi``. It must
    be a probability distribution, checked by
    ``assert_probablity_distribution``. It must also be close to the
    fixture's true ``topics``, checked by ``check_KL_divergence`` with
    threshold ``KL_thresh`` (presumably an upper bound on the KL divergence;
    confirm against the helper).
    """
    # Synthetic corpus fixture. It supplies the seed, K, D, V, alpha, thetas,
    # topics, doc_term_matrix and n_report_iters used below.
    lang = language(10000)
    n_iter = 1500
    KL_thresh = 0.03
    mu = 0.
    nu2 = 1.
    np.random.seed(lang['seed'])
    # nu2 is presumably the prior variance (nu^2). np.random.normal expects a
    # standard deviation, so take the square root. At nu2 == 1 the draws are
    # unchanged.
    eta = np.random.normal(loc=mu, scale=np.sqrt(nu2), size=lang['K'])
    # Linear score per document: dot product of eta with its topic proportions.
    zeta = np.array([np.dot(eta, lang['thetas'][i]) for i in range(lang['D'])])
    # Threshold at zero to obtain binary (0/1) labels.
    y = (zeta >= 0).astype(int)
    # Symmetric topic-word prior of 0.01 over the vocabulary.
    _beta = np.repeat(0.01, lang['V'])
    _b = 7.25
    blslda = BLSLDA(lang['K'], lang['alpha'], _beta, mu, nu2, _b, n_iter,
                    seed=lang['seed'],
                    n_report_iter=lang['n_report_iters'])
    blslda.fit(lang['doc_term_matrix'], y)
    assert_probablity_distribution(blslda.phi)
    check_KL_divergence(lang['topics'], blslda.phi, KL_thresh)
评论列表
文章目录