prior.py source code

python

Project: attend_infer_repeat  Author: akosiorek

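import tensorflow as tf

# clip_preserve and masked_apply are helper functions defined elsewhere in the
# attend_infer_repeat project; they are not shown in this snippet.


def tabular_kl(p, q, zero_prob_value=0., logarg_clip=None):
    """Computes KL-divergence KL(p||q) for two probability mass functions (pmf) given in tabular form.

    :param p: iterable
    :param q: iterable
    :param zero_prob_value: float; entries of p that are not greater than this threshold are treated as zero
    :param logarg_clip: float or None; if not None, clips the argument of the log to lie in [1 / logarg_clip, logarg_clip]
    :return: iterable with the broadcast shape of (p * q), holding the per-coordinate terms of KL(p||q)
    """
    # Work in double precision to limit round-off in the ratio and the log.
    p, q = (tf.cast(i, tf.float64) for i in (p, q))
    # Coordinates with p <= zero_prob_value are masked out of the log below.
    non_zero = tf.greater(p, zero_prob_value)
    logarg = p / q

    if logarg_clip is not None:
        # Bound the ratio p / q to [1 / logarg_clip, logarg_clip] so its log stays finite.
        logarg = clip_preserve(logarg, 1. / logarg_clip, logarg_clip)

    # masked_apply presumably applies tf.log only where non_zero is True
    # (tf.log is the TF 1.x name; TF 2.x uses tf.math.log).
    log = masked_apply(logarg, tf.log, non_zero)
    # Per-coordinate terms p_i * log(p_i / q_i); their sum is KL(p||q).
    kl = p * log

    return tf.cast(kl, tf.float32)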
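The TensorFlow function relies on the project helpers clip_preserve and masked_apply, which this page does not show. For checking values outside a TF graph, here is a minimal NumPy sketch of the same per-coordinate computation, p_i * log(p_i / q_i). It is an illustration under stated assumptions, not the project's code: tabular_kl_np is a hypothetical name, plain np.clip is assumed as a stand-in for clip_preserve (whose gradient behaviour is not shown here), and boolean indexing stands in for masked_apply.

```python
import numpy as np


def tabular_kl_np(p, q, zero_prob_value=0.0, logarg_clip=None):
    """Hypothetical NumPy sketch of per-coordinate KL(p||q) terms.

    Assumptions: np.clip replaces the project's clip_preserve, and boolean
    indexing replaces masked_apply (log is taken only where p is non-zero).
    """
    # Double precision, broadcast to a common shape so the mask matches logarg.
    p, q = np.broadcast_arrays(np.asarray(p, dtype=np.float64),
                               np.asarray(q, dtype=np.float64))
    non_zero = p > zero_prob_value

    # p / q may contain inf (q == 0) or nan (p == q == 0); nan entries are masked out.
    with np.errstate(divide="ignore", invalid="ignore"):
        logarg = p / q

    if logarg_clip is not None:
        # Stand-in for clip_preserve: bound the ratio to [1 / clip, clip].
        logarg = np.clip(logarg, 1.0 / logarg_clip, logarg_clip)

    # Log only on unmasked coordinates; masked coordinates contribute 0.
    log = np.zeros_like(logarg)
    log[non_zero] = np.log(logarg[non_zero])

    kl = p * log
    return kl.astype(np.float32)


# Usage: summing the per-coordinate terms gives the divergence itself.
terms = tabular_kl_np([0.5, 0.5, 0.0], [0.25, 0.5, 0.25])
total = float(terms.sum())  # 0.5 * ln 2
```

Returning per-coordinate terms rather than a scalar, as the original does, leaves the choice of reduction axis to the caller, which matters when p and q carry a batch dimension.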