def gaussian_kl_div(mean_0, cov_0, mean_1, cov_1, dim):
    """Return KL(N_0 || N_1) between two diagonal-covariance Gaussians.

    Evaluates the closed form

        0.5 * (tr(S1^-1 S0) + (m1 - m0)^T S1^-1 (m1 - m0) - dim
               + log det S1 - log det S0)

    with every covariance operation done element-wise, i.e. ``cov_0`` and
    ``cov_1`` are treated as the diagonal entries (per-dimension variances)
    of the covariance matrices, not as full matrices.

    Args:
        mean_0: Mean of the first Gaussian (N_0).
        cov_0: Diagonal covariance entries of N_0; must be positive because
            their logarithm is taken.
        mean_1: Mean of the second Gaussian (N_1).
        cov_1: Diagonal covariance entries of N_1; must be positive because
            they are inverted and their logarithm is taken.
        dim: Dimensionality constant subtracted in the closed form.
            NOTE(review): presumably the size of axis 1 of the inputs --
            confirm against callers.

    Returns:
        Tensor holding the divergence with axis 1 of the inputs summed out.
        NOTE(review): inputs look like (batch, dim) tensors, which would
        give one KL value per batch row -- confirm against callers.
    """
    # All sums run over axis 1, which is treated as the feature axis.
    feature_axis = [1]

    delta = mean_1 - mean_0
    # Inverse of a diagonal covariance is the element-wise reciprocal.
    precision_1 = tf.reciprocal(cov_1)

    # log-determinant of a diagonal matrix = sum of the log diagonal entries.
    log_det_1 = tf.reduce_sum(tf.log(cov_1), axis=feature_axis)
    log_det_0 = tf.reduce_sum(tf.log(cov_0), axis=feature_axis)
    log_det_ratio = log_det_1 - log_det_0

    # tr(S1^-1 S0) for diagonal matrices.
    trace = tf.reduce_sum(precision_1 * cov_0, axis=feature_axis)
    # Squared Mahalanobis distance between the means under S1.
    mahalanobis = tf.reduce_sum(delta * precision_1 * delta, axis=feature_axis)

    return 0.5 * (trace + mahalanobis - dim + log_det_ratio)
评论列表
文章目录