def _log_prob(self, given):
    """Return the binomial log-probability of observing ``given`` successes.

    With ``p = sigmoid(self.logits)`` and ``n = self.n_experiments`` this is::

        log C(n, given) + given * log(p) + (n - given) * log(1 - p)

    rewritten in terms of the logits as::

        lgamma(n + 1) - lgamma(n - given + 1) - lgamma(given + 1)
            + given * logits - n * softplus(logits)

    :param given: Observed success counts; cast to ``self.param_dtype``
        before use. Presumably broadcast-compatible with ``self.logits`` --
        TODO confirm against callers.
    :return: A tensor of log-probabilities.
    """
    logits = self.logits
    num_trials = tf.cast(self.n_experiments, self.param_dtype)
    k = tf.cast(given, self.param_dtype)

    # log(1 - p) for p = sigmoid(logits), i.e. log(sigmoid(-logits)).
    log_q = -tf.nn.softplus(logits)

    # Log-factorials via the identity lgamma(x + 1) == log(x!).
    log_fact_n = tf.lgamma(num_trials + 1)
    log_fact_k = tf.lgamma(k + 1)
    log_fact_n_minus_k = tf.lgamma(num_trials - k + 1)

    # Only the terms that depend on `given` are guarded: an integer count
    # below 0 or above n lands lgamma on a pole and yields Inf.
    if self._check_numerics:
        log_fact_k = tf.check_numerics(log_fact_k, "lgamma(given + 1)")
        log_fact_n_minus_k = tf.check_numerics(
            log_fact_n_minus_k, "lgamma(n - given + 1)")

    # log C(n, k), evaluated left to right.
    log_binom_coeff = log_fact_n - log_fact_n_minus_k - log_fact_k
    # k*log(p) + (n - k)*log(1 - p) collapses to k*logits + n*log(1 - p),
    # since log(p) - log(1 - p) == logits.
    return log_binom_coeff + k * logits + num_trials * log_q
评论列表
文章目录