def tabular_kl(p, q, zero_prob_value=0., logarg_clip=None):
    """Computes KL-divergence KL(p||q) for two probability mass functions (pmf) given in a tabular form.

    The result is NOT summed: every coordinate holds p_i * log(p_i / q_i).
    The log is only evaluated where p_i is greater than `zero_prob_value`.
    At the remaining coordinates the value comes from whatever `masked_apply`
    stores at masked-out positions (multiplied by p_i). Presumably that value is
    zero, which gives the usual 0 * log(0 / q) = 0 convention — TODO confirm
    against `masked_apply`.

    Inputs are cast to float64 for the computation; the result is cast to float32.

    :param p: iterable (tensor-like) of probabilities of the first pmf
    :param q: iterable (tensor-like) of probabilities of the second pmf; must broadcast against `p`
    :param zero_prob_value: float; entries of `p` less than or equal to this threshold are treated as zero
        (the log is not evaluated there)
    :param logarg_clip: float or None; if not None, clips the argument of the log, p / q, to lie in
        [1 / logarg_clip, logarg_clip]. Presumably expected to be > 1 so that the interval is non-empty.
    :return: float32 tensor of the broadcasted shape of (p * q), per-coordinate value of KL(p||q)
    """
    p, q = (tf.cast(i, tf.float64) for i in (p, q))

    # Broadcast p and q to their common shape. This makes the mask built from p
    # below have exactly the same shape as the log argument it is applied to.
    # For equal-shaped inputs this is a value-preserving no-op.
    p_b = p + tf.zeros_like(q)
    q_b = q + tf.zeros_like(p)

    # Skip the log where p is (treated as) zero: there p / q is 0 (log -> -inf)
    # or NaN (0 / 0).
    non_zero = tf.greater(p_b, zero_prob_value)
    logarg = p_b / q_b

    if logarg_clip is not None:
        # `clip_preserve` is defined elsewhere in the project; judging by its
        # name it clips while preserving gradients — TODO confirm.
        logarg = clip_preserve(logarg, 1. / logarg_clip, logarg_clip)

    log = masked_apply(logarg, tf.log, non_zero)
    kl = p_b * log

    return tf.cast(kl, tf.float32)
评论列表
文章目录