def kl_sym(self, old_dist_info_vars, new_dist_info_vars):
    """Compute KL(old || new) between two Gaussians with diagonal covariance.

    The divergence is built from ``TT`` operations (presumably
    ``theano.tensor``, i.e. a symbolic expression -- confirm against the
    file's imports). Per-dimension divergences are summed over the last
    axis.

    Args:
        old_dist_info_vars: mapping with ``"mean"`` and ``"log_std"``
            entries for the first distribution of the KL (the "old" one).
        new_dist_info_vars: mapping with ``"mean"`` and ``"log_std"``
            entries for the second distribution (the "new" one).

    Returns:
        ``TT.sum(..., axis=-1)`` of the per-dimension KL terms, i.e. one
        divergence per leading index.
    """
    old_means = old_dist_info_vars["mean"]
    old_log_stds = old_dist_info_vars["log_std"]
    new_means = new_dist_info_vars["mean"]
    new_log_stds = new_dist_info_vars["log_std"]

    old_std = TT.exp(old_log_stds)
    new_std = TT.exp(new_log_stds)

    # Assumes means / log_stds are (N, A): N samples by A dimensions,
    # as the original comment suggested -- TODO confirm with callers.
    #
    # Closed form per dimension, with the usual "- 1/2" term folded into
    # the numerator as "- sigma_new^2":
    #   {(mu_old - mu_new)^2 + sigma_old^2 - sigma_new^2} / (2 sigma_new^2)
    #   + ln(sigma_new / sigma_old)
    # ln(sigma_new / sigma_old) is taken directly as the log-std difference.
    numerator = (TT.square(old_means - new_means)
                 + TT.square(old_std) - TT.square(new_std))
    # The 1e-8 keeps the division finite if new_std underflows to zero; it
    # otherwise biases the result by a negligible amount.
    denominator = 2 * TT.square(new_std) + 1e-8
    return TT.sum(
        numerator / denominator + new_log_stds - old_log_stds, axis=-1)
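# NOTE(review): scraped web-page residue, not code -- "Comment list"
# NOTE(review): scraped web-page residue, not code -- "Table of contents"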