def _solve_beta(lo, hi, gamma, MT, B, f):
    """Solve ``B @ beta_j = a_j`` for every row ``j`` in ``range(lo, hi)``.

    For each ``j``, ``get_row(MT, j)`` yields a pair ``(m_j, idx_m_j)`` and the
    right-hand side is ``a_j = m_j @ gamma[idx_m_j]``.  All right-hand sides
    are then solved against the shared matrix ``B`` in a single ``LA.solve``
    call, so ``B`` is factorized once per batch instead of once per row.

    NOTE(review): this reads like an ALS half-step in which ``MT`` is a sparse
    matrix and ``get_row`` returns the non-zero values and column indices of
    row ``j`` -- confirm against ``get_row``.

    Args:
        lo, hi: Half-open range ``[lo, hi)`` of row indices to solve.
        gamma: 2-D array whose rows are gathered by ``idx_m_j``; its dtype
            determines the dtype of the result.
        MT: Container passed through unchanged to ``get_row``.
        B: Square coefficient matrix shared by every row (presumably
            ``(f, f)`` -- TODO confirm).
        f: Number of columns of the result.

    Returns:
        Array of shape ``(hi - lo, f)`` and dtype ``gamma.dtype`` whose row
        ``ib`` is the solution for row index ``lo + ib``.

    Raises:
        ValueError: from ``np.empty`` if ``hi < lo``.
    """
    beta_batch = np.empty((hi - lo, f), dtype=gamma.dtype)
    if hi <= lo:
        # Nothing to solve; avoid handing an empty right-hand side to LA.solve.
        return beta_batch

    rhs_rows = []
    for j in range(lo, hi):
        m_j, idx_m_j = get_row(MT, j)
        rhs_rows.append(np.dot(m_j, gamma[idx_m_j]))

    # B does not change across rows: stack the right-hand sides as columns
    # and solve them together so B is factorized once, not hi - lo times.
    rhs = np.vstack(rhs_rows)
    # Assigning into the preallocated buffer casts to gamma.dtype, exactly as
    # the per-row assignment did.
    beta_batch[:] = LA.solve(B, rhs.T).T
    return beta_batch
评论列表
文章目录