def _solve_topic(lo, hi, theta, alpha, gamma, XT, BTBpR, c0, c1, f, lam_d, lam_t):
    """Solve one linear system per row index in ``[lo, hi)`` and clip at zero.

    For every index ``u`` in the range, ``get_row(XT, u)`` supplies the
    row's values together with the indices used to select rows of
    ``theta``.  The right-hand side is the sum of a ``lam_d``-weighted
    term built from those values and a ``lam_t``-weighted term built from
    ``gamma[u]`` and ``alpha``.  The system matrix is ``BTBpR`` plus a
    correction computed from the selected rows of ``theta`` only.

    Returns:
        A new array of shape ``(hi - lo, f)`` with dtype ``theta.dtype``;
        each row is the solution for one index, with values below zero
        replaced by zero.
    """
    solutions = np.empty((hi - lo, f), dtype=theta.dtype)

    for slot, u in enumerate(range(lo, hi)):
        row_values, row_indices = get_row(XT, u)

        # Per the original author's note: only the non-zero elements are
        # handled inside this loop, so ``selected`` holds just the vectors
        # that correspond to non-zero doc-word occurrences.
        selected = theta[row_indices]

        # Right-hand side: data term plus the gamma/alpha term.
        data_term = lam_d * row_values.dot(c1 * selected)
        prior_term = lam_t * gamma[u].dot(alpha.T)
        rhs = data_term + prior_term

        # System matrix: shared part plus the per-row correction.
        correction = selected.T.dot((c1 - c0) * selected)
        system = BTBpR + correction

        solutions[slot] = LA.solve(system, rhs)

    return solutions.clip(0)
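# Web-page navigation residue (not part of the code): "Comment list"
# Web-page navigation residue (not part of the code): "Table of contents"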