def _log_prob(self, given):
    """
    Compute the log-density of `given` from the logits and temperature.

    Both parameters are first passed through ``self.path_param``. With
    ``t = logits - temperature * given``, the returned tensor is::

        lgamma(n) + (n - 1) * log(temperature) + sum(t) - n * logsumexp(t)

    where ``n`` is ``self.n_categories`` cast to ``self.dtype`` and both
    reductions run over the last axis.

    NOTE(review): this has the form of an ExpConcrete (log-space
    Gumbel-softmax) log-density; confirm against the enclosing class.

    :param given: Tensor at which to evaluate the density. It is combined
        elementwise with the logits, so it is presumably shaped like
        them (TODO confirm).
    :return: A tensor with the last axis of `given` reduced away.
    """
    logits = self.path_param(self.logits)
    temperature = self.path_param(self.temperature)
    n = tf.cast(self.n_categories, self.dtype)

    log_temperature = tf.log(temperature)
    if self._check_numerics:
        # Surface NaN/Inf from log(temperature) when the check is enabled.
        log_temperature = tf.check_numerics(
            log_temperature, "log(temperature)")

    # Per-category term shared by the plain sum and the log-sum-exp.
    scaled = logits - temperature * given
    summed = tf.reduce_sum(scaled, axis=-1)
    log_normalizer = tf.reduce_logsumexp(scaled, axis=-1)

    return (tf.lgamma(n) + (n - 1) * log_temperature + summed
            - n * log_normalizer)
评论列表
文章目录