def logprob_dc(counts, prior, axis=None):
    """Return the non-normalized log probability of Dirichlet-Categorical counts.

    Reference: https://en.wikipedia.org/wiki/Dirichlet-multinomial_distribution

    Only the sum of ``log Gamma(counts + prior)`` terms is evaluated; the other
    factors of the Dirichlet-multinomial density are not computed here.

    Args:
        counts: Observed counts; combined with ``prior`` by broadcasting
            addition.
        prior: Dirichlet concentration parameter(s) added to ``counts``.
        axis: Axis (or axes) to sum over, forwarded to ``ndarray.sum``;
            ``None`` sums over every element.

    Returns:
        The summed log-Gamma terms.
    """
    # The factorial(counts) term is deliberately left out: permutations are
    # tracked explicitly in the assignments instead.
    # The addition is carried out in float32 (single precision).
    shifted = np.add(counts, prior, dtype=np.float32)
    log_gamma_terms = gammaln(shifted)
    return log_gamma_terms.sum(axis=axis)
评论列表
文章目录