def test_grtm():
    """Check that GRTM recovers the topics of a synthetic corpus.

    Builds a synthetic corpus with ``language(1000)``, draws a random
    K x K matrix ``H`` and labels every ordered document pair ``(i, j)``
    as linked (``y = 1``) when ``theta_i @ H @ theta_j >= 0``.  GRTM is
    fit on a stratified subset of those labelled pairs; the learned
    topic-word matrix ``phi`` must be a probability distribution whose
    KL divergence from the true topics stays within ``KL_thresh``.
    """
    l = language(1000)
    n_iter = 1000
    KL_thresh = 0.3

    mu = 0.
    nu2 = 1.
    np.random.seed(l['seed'])
    # NOTE(review): ``scale`` of np.random.normal is a standard deviation;
    # if ``nu2`` is meant as a variance this should be np.sqrt(nu2).  The
    # two coincide here because nu2 == 1.
    H = np.random.normal(loc=mu, scale=nu2, size=(l['K'], l['K']))

    # Score all D**2 ordered pairs at once: entry [i, j] of
    # thetas @ H @ thetas.T equals theta_i @ H @ theta_j (up to
    # floating-point summation order).  This replaces a Python-level loop
    # of 2 * D**2 np.dot calls.
    # Assumes l['thetas'] stacks into a (D, K) array -- TODO confirm.
    D = l['D']
    thetas = np.asarray(l['thetas'])[:D]
    scores = thetas.dot(H).dot(thetas.T)
    # Row-major ravel matches itertools.product(range(D), repeat=2):
    # 'tail' is the outer index, 'head' the inner one.
    nodes = np.arange(D)
    zeta = pd.DataFrame({'tail': np.repeat(nodes, D),
                         'head': np.tile(nodes, D),
                         'zeta': scores.ravel()},
                        columns=['tail', 'head', 'zeta'])
    zeta['y'] = (zeta.zeta >= 0).astype(int)
    y = zeta[['tail', 'head', 'y']].values

    # Legacy ``sklearn.cross_validation`` signature (labels + ``n_folds``);
    # scikit-learn >= 0.20 spells this StratifiedKFold(n_splits=100).split.
    skf = StratifiedKFold(y[:, 2], n_folds=100)
    # The first fold yields (train, test); the *test* indices -- roughly
    # 1% of all pairs, stratified on y -- are kept as the training edges,
    # presumably to keep the fit cheap.
    _, train_idx = next(iter(skf))

    _K = l['K']
    _alpha = l['alpha'][:_K]
    _beta = np.repeat(0.01, l['V'])
    _b = 1.
    grtm = GRTM(_K, _alpha, _beta, mu, nu2, _b, n_iter, seed=l['seed'],
                n_report_iter=l['n_report_iters'])
    grtm.fit(l['doc_term_matrix'], y[train_idx])

    assert_probablity_distribution(grtm.phi)
    check_KL_divergence(l['topics'], grtm.phi, KL_thresh)
# Web-page scrape residue, not code: 评论列表 ("comment list"),
# 文章目录 ("table of contents").