test_slda.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:slda 作者: Savvysherpa 项目源码 文件源码
def test_grtm():
    """Fit a GRTM on synthetic linked documents and check topic recovery.

    Documents, true topics and per-document topic proportions (``thetas``)
    come from the ``language`` fixture. A link label is synthesized for every
    ordered pair of documents (i, j) from the bilinear score
    ``zeta_ij = theta_i . H . theta_j``, where ``H`` is a random K x K matrix;
    the label is 1 when ``zeta_ij >= 0`` and 0 otherwise. A small stratified
    subsample of these links is used to fit the model. The learned
    topic-word matrix ``grtm.phi`` must then pass
    ``assert_probablity_distribution`` and ``check_KL_divergence`` against
    the true topics with threshold ``KL_thresh``.
    """
    lang = language(1000)
    n_iter = 1000
    KL_thresh = 0.3

    # Gaussian prior parameters; the same values are handed to GRTM below.
    # nu2 is a variance (nu squared), not a standard deviation.
    mu = 0.
    nu2 = 1.
    np.random.seed(lang['seed'])
    # np.random.normal expects a standard deviation for ``scale``, so take
    # sqrt of the variance. Identical draws for the current nu2 == 1.
    H = np.random.normal(loc=mu, scale=np.sqrt(nu2),
                         size=(lang['K'], lang['K']))

    # One row per ordered document pair (tail, head) with its bilinear score.
    thetas = lang['thetas']
    zeta = pd.DataFrame([(i, j, np.dot(np.dot(thetas[i], H), thetas[j]))
                         for i, j in product(range(lang['D']), repeat=2)],
                        columns=('tail', 'head', 'zeta'))
    zeta['y'] = (zeta.zeta >= 0).astype(int)
    y = zeta[['tail', 'head', 'y']].values

    # Legacy scikit-learn signature: labels go to the constructor and the
    # splitter itself is iterable, yielding (train, test) index pairs.
    skf = StratifiedKFold(y[:, 2], n_folds=100)
    # Deliberately keep the *test* indices of the first fold (a stratified
    # ~1/100 sample of all links) as the fitting set, presumably to keep the
    # fit small and fast.
    _, train_idx = next(iter(skf))

    n_topics = lang['K']
    alpha = lang['alpha'][:n_topics]
    beta = np.repeat(0.01, lang['V'])
    b = 1.
    grtm = GRTM(n_topics, alpha, beta, mu, nu2, b, n_iter, seed=lang['seed'],
                n_report_iter=lang['n_report_iters'])
    grtm.fit(lang['doc_term_matrix'], y[train_idx])

    assert_probablity_distribution(grtm.phi)
    check_KL_divergence(lang['topics'], grtm.phi, KL_thresh)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号