def _log_prob(self, given):
    """Evaluate the log-density of ``given`` under concentration ``self.alpha``.

    Implements the (Dirichlet) log-density formula

        log p(x) = -log B(alpha) + sum_k (alpha_k - 1) * log(x_k)

    where ``log B`` is the log multivariate Beta function from ``tf.lbeta``
    and the sum runs over the last axis.

    Args:
        given: Tensor of values to score; its last axis is the one reduced.

    Returns:
        Tensor of log-densities whose shape is that of the (broadcast)
        inputs with the last axis removed.

    NOTE(review): ``tf.log(given)`` is -inf wherever a component of
    ``given`` is 0. Where the matching ``alpha`` is exactly 1 the product
    ``0 * -inf`` yields NaN, and when ``self._check_numerics`` is enabled the
    -inf itself makes ``tf.check_numerics`` raise. Confirm that callers never
    evaluate points on the boundary of the support.
    """
    # Bring both operands to a common shape before the element-wise
    # arithmetic below (helper defined elsewhere; presumably explicit
    # broadcasting -- confirm against its definition).
    given, alpha = maybe_explicit_broadcast(
        given, self.alpha, 'given', 'alpha')
    alpha_shape = alpha.get_shape()

    # Log of the normaliser B(alpha); tf.lbeta reduces over the last axis.
    log_norm = tf.lbeta(alpha)
    # tf.lbeta does not propagate a static output shape, so restore it by
    # hand whenever alpha's rank is statically known.
    if alpha_shape.ndims is not None:
        log_norm.set_shape(alpha_shape[:-1])

    log_x = tf.log(given)
    if self._check_numerics:
        log_norm = tf.check_numerics(log_norm, "lbeta(alpha)")
        log_x = tf.check_numerics(log_x, "log(given)")

    # Unnormalised part of the density: sum_k (alpha_k - 1) * log(x_k).
    log_kernel = tf.reduce_sum((alpha - 1) * log_x, -1)
    return log_kernel - log_norm
评论列表
文章目录