def _log_prob(self, given):
    """Build the log-density expression of `given` from `alpha` and `beta`.

    The value returned is::

        (alpha - 1) * log(given) + (beta - 1) * log(1 - given)
            - (lgamma(alpha) + lgamma(beta) - lgamma(alpha + beta))

    which has the form of a Beta(alpha, beta) log-density, with the
    parenthesised lgamma group acting as the log normalizer.

    When `self._check_numerics` is truthy, every log / lgamma term is
    passed through `tf.check_numerics`, tagged with a message naming the
    term, before being combined.

    :param given: Value(s) at which to evaluate; fed directly to `tf.log`.
    :return: The result of the expression above.
    """
    # TODO: not right when given=0 or 1
    # (at those points log(given) or log(1 - given) is log(0)).
    alpha = self.alpha
    beta = self.beta

    # Terms are created in a fixed order: the two logs first, then the
    # three lgamma values.
    terms = (
        tf.log(given),
        tf.log(1 - given),
        tf.lgamma(alpha),
        tf.lgamma(beta),
        tf.lgamma(alpha + beta),
    )

    if self._check_numerics:
        # One message per term, in the same order as `terms`.
        messages = (
            "log(given)",
            "log(1 - given)",
            "lgamma(alpha)",
            "lgamma(beta)",
            "lgamma(alpha + beta)",
        )
        terms = tuple(
            tf.check_numerics(term, message)
            for term, message in zip(terms, messages)
        )

    log_x, log_one_minus_x, lg_alpha, lg_beta, lg_alpha_beta = terms

    unnormalized = (alpha - 1) * log_x + (beta - 1) * log_one_minus_x
    log_normalizer = lg_alpha + lg_beta - lg_alpha_beta
    return unnormalized - log_normalizer
评论列表
文章目录