def log_p_multinomial(X, p):
    """Return the log probability of a vector of counts under a multinomial.

    Computes log[ n! / (x_1! ... x_k!) * p_1^x_1 * ... * p_k^x_k ] with
    n = sum(X), working in log space (via log-gamma) so large counts do not
    overflow.

    Args:
        X: array-like of counts representing a multinomial draw.
        p: array-like of category probabilities aligned with X; must sum
            to 1 within a tolerance of 1e-4.

    Returns:
        The log probability of the draw as a float. An all-zero draw returns
        0.0 without validating p. A positive count in a category whose
        probability is 0 yields -inf.

    Raises:
        AssertionError: (when assertions are enabled) if X and p differ in
            length or p does not sum to 1.
    """
    counts = np.asarray(X)
    probs = np.asarray(p)
    # An empty draw (n == 0) has probability 1 regardless of p.
    if counts.sum() == 0:
        return 0.0
    # check that input is valid
    assert len(counts) == len(probs), "X and p must have the same length"
    eps = 0.0001
    assert abs(probs.sum() - 1.0) < eps, "p must sum to 1"
    # log of the multinomial coefficient n! / prod(x_i!)
    log_n_choices = (special.gammaln(counts.sum() + 1)
                     - special.gammaln(counts + 1).sum())
    # xlogy(x, p) is defined as 0 when x == 0, so a zero-probability category
    # with no counts contributes nothing; x * log(p) would give 0 * -inf = nan.
    log_p_items = special.xlogy(counts, probs).sum()
    return float(log_n_choices + log_p_items)
评论列表
文章目录