def test_crp_decrement(N, alpha, seed):
    """Check CRP counts after removing one row from each multi-member cluster.

    Builds the expected per-cluster counts from ``gu.simulate_crp``, then
    decrements every count greater than one. Singleton clusters are left
    untouched, mirroring the rows that are actually removed below.

    It then builds a CRP model with ``simulate_crp_gpm`` from the same seed.
    One row is unincorporated from each cluster whose count exceeds one. The
    result is compared against the expected counts with
    ``assert_crp_equality``.

    Args:
        N: passed as the first argument to both simulators (presumably the
            number of rows -- confirm against ``gu.simulate_crp``).
        alpha: presumably the CRP concentration parameter, passed to both
            simulators.
        seed: seed for ``gu.gen_rng``. The same seed is used for both
            simulations, presumably so they yield identical assignments.
    """
    A = gu.simulate_crp(N, alpha, rng=gu.gen_rng(seed))
    Nk = list(np.bincount(A))
    # Expected counts: decrement every cluster that has more than one member.
    Nk = [n - 1 if n > 1 else n for n in Nk]

    crp = simulate_crp_gpm(N, alpha, rng=gu.gen_rng(seed))
    # Clusters from which exactly one row will be removed.
    targets = {c for c in crp.counts if crp.counts[c] > 1}
    seen = set()
    # Iterate over a snapshot: unincorporate() may remove the row from
    # crp.data, and mutating a dict while iterating it raises RuntimeError.
    for r, c in list(crp.data.items()):
        if c in targets and c not in seen:
            seen.add(c)
            crp.unincorporate(r)
            # Stop once one row has been removed from every target cluster.
            # (The original compared the set itself to an int, which never
            # matched, so the loop never broke early.)
            if len(seen) == len(targets):
                break
    assert_crp_equality(alpha, Nk, crp)
# Comment list
# Table of contents