test_slda.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:slda 作者: Savvysherpa 项目源码 文件源码
def test_slda():
    l = language(10000)
    n_iter = 2000
    KL_thresh = 0.001

    nu2 = l['K']
    sigma2 = 1
    np.random.seed(l['seed'])
    eta = np.random.normal(scale=nu2, size=l['K'])
    y = [np.dot(eta, l['thetas'][i]) for i in range(l['D'])] + \
        np.random.normal(scale=sigma2, size=l['D'])
    _beta = np.repeat(0.01, l['V'])
    _mu = 0
    slda = SLDA(l['K'], l['alpha'], _beta, _mu, nu2, sigma2, n_iter,
                seed=l['seed'], n_report_iter=l['n_report_iters'])
    slda.fit(l['doc_term_matrix'], y)

    assert_probablity_distribution(slda.phi)
    check_KL_divergence(l['topics'], slda.phi, KL_thresh)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号