def _solve_gamma(lo, hi, alpha, beta, topic, M, B, c0, c1, f, lam_w, lam_t):
    """Solve the linear update for rows ``lo:hi`` of gamma.

    For every row index ``i`` in ``range(lo, hi)`` a right-hand side

        a_i = lam_t * topic[i, :] @ alpha + lam_w * m_i @ beta[idx_m_i]

    is formed, where ``m_i`` and ``idx_m_i`` are the two values returned by
    ``get_row(M, i)`` (presumably the stored values of row ``i`` and their
    column indices -- confirm against ``get_row``).  Row ``i - lo`` of the
    result is the solution ``x`` of ``B @ x = a_i``.

    Parameters
    ----------
    lo, hi : int
        Half-open range of row indices to update.
    alpha : ndarray
        Matrix multiplied by ``topic[i, :]``; its dtype is the output dtype.
    beta : ndarray
        Matrix whose rows are gathered with ``idx_m_i``.
    topic : ndarray
        Matrix read row-wise as ``topic[i, :]``.
    M :
        Matrix passed to ``get_row`` to obtain row ``i``.
    B : ndarray, shape (f, f)
        Left-hand-side matrix; the same ``B`` is used for every row.
    c0, c1 :
        Not used by this function; kept for call-signature compatibility.
    f : int
        Number of columns of the result.
    lam_w, lam_t : float
        Weights of the ``M``/``beta`` term and the ``topic``/``alpha`` term.

    Returns
    -------
    ndarray, shape (hi - lo, f), dtype ``alpha.dtype``
    """
    # Allocated first so hi < lo still raises exactly as before.
    gamma_batch = np.empty((hi - lo, f), dtype=alpha.dtype)
    if gamma_batch.shape[0] == 0:
        return gamma_batch

    rhs = []
    for i in range(lo, hi):
        t_i = topic[i, :]
        m_i, idx_m_i = get_row(M, i)
        # The M-term is built only from the entries get_row reports for this
        # row (the objective reportedly considers only the non-zero elements
        # of W), so it is assembled per row rather than as one dense product.
        B_i = beta[idx_m_i]
        rhs.append(lam_t * np.dot(t_i, alpha) + lam_w * np.dot(m_i, B_i))

    # B does not depend on i: factor it once and solve all right-hand sides
    # together (columns of the (f, n) RHS matrix) instead of re-factoring B
    # for every row. Assignment casts to alpha.dtype, as before.
    gamma_batch[:] = LA.solve(B, np.array(rhs).T).T
    return gamma_batch
评论列表
文章目录