vrnn_model.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:vrnn 作者: frhrdr 项目源码 文件源码
def gaussian_kl_div(mean_0, cov_0, mean_1, cov_1, dim):
    """ computes KL divergences between two Gaussians with given parameters"""
    mean_diff = mean_1 - mean_0
    cov_1_inv = tf.reciprocal(cov_1)
    log_cov_1_det = tf.reduce_sum(tf.log(cov_1), axis=[1])
    log_cov_0_det = tf.reduce_sum(tf.log(cov_0), axis=[1])
    log_term = log_cov_1_det - log_cov_0_det
    trace_term = tf.reduce_sum(cov_1_inv * cov_0, axis=[1])
    square_term = tf.reduce_sum(mean_diff * cov_1_inv * mean_diff, axis=[1])
    kl_div = 0.5 * (trace_term + square_term - dim + log_term)
    return kl_div
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号