def _log_prob(self, given):
    """Return the log-probability of ``given`` under this distribution.

    The value is ``k * log(rate) - rate - log Gamma(k + 1)``, where
    ``k = given`` has been cast to ``self.param_dtype``. This is the
    Poisson log-pmf ``k log(lambda) - lambda - log(k!)``.

    Args:
        given: Observed values. They are cast to ``self.param_dtype``.
            Presumably these are non-negative integer counts that
            broadcast against ``self.rate``; TODO confirm against callers.

    Returns:
        A tensor holding the elementwise log-probability.

    NOTE(review): a zero entry in ``self.rate`` makes ``log(rate)`` equal
    -inf. When ``self._check_numerics`` is set, that value is routed through
    ``tf.check_numerics``.
    """
    lam = self.rate
    counts = tf.cast(given, self.param_dtype)
    log_lam = tf.log(lam)
    # log(k!) is computed as log Gamma(k + 1) so that it works on
    # floating-point tensors.
    log_factorial = tf.lgamma(counts + 1)

    if self._check_numerics:
        log_lam = tf.check_numerics(log_lam, "log(rate)")
        log_factorial = tf.check_numerics(
            log_factorial, "lgamma(given + 1)")

    # Keep the original left-to-right subtraction order so floating-point
    # results are unchanged.
    unnormalized = counts * log_lam - lam
    return unnormalized - log_factorial
评论列表
文章目录