def tf_kl_divergence(self, distr_params1, distr_params2):
    # Unpack Beta parameters: alpha, beta, alpha + beta, and log B(alpha, beta).
    alpha1, beta1, alpha_beta1, log_norm1 = distr_params1
    alpha2, beta2, alpha_beta2, log_norm2 = distr_params2
    # Closed-form KL(Beta_1 || Beta_2), with psi the digamma function:
    #   log B(a2, b2) - log B(a1, b1) + (a1 - a2) psi(a1)
    #   + (b1 - b2) psi(b1) + ((a2 + b2) - (a1 + b1)) psi(a1 + b1)
    return log_norm2 - log_norm1 - tf.digamma(x=beta1) * (beta2 - beta1) - \
        tf.digamma(x=alpha1) * (alpha2 - alpha1) + \
        tf.digamma(x=alpha_beta1) * (alpha_beta2 - alpha_beta1)
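To see the same formula outside the class, here is a minimal standalone sketch. It assumes TensorFlow 2 (where digamma and lgamma live under tf.math) and uses hypothetical helper names beta_params and kl_beta; log B(alpha, beta) is built from lgamma, mirroring what the method's precomputed log_norm and alpha_beta entries would contain.

import tensorflow as tf

def beta_params(alpha, beta):
    # Assumed helper: packs (alpha, beta, alpha + beta, log B(alpha, beta)),
    # using log B(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b).
    alpha = tf.constant(alpha, dtype=tf.float64)
    beta = tf.constant(beta, dtype=tf.float64)
    log_norm = tf.math.lgamma(alpha) + tf.math.lgamma(beta) \
        - tf.math.lgamma(alpha + beta)
    return alpha, beta, alpha + beta, log_norm

def kl_beta(p1, p2):
    # Same closed-form KL as the method above, written with tf.math.digamma.
    alpha1, beta1, alpha_beta1, log_norm1 = p1
    alpha2, beta2, alpha_beta2, log_norm2 = p2
    return (log_norm2 - log_norm1
            - tf.math.digamma(beta1) * (beta2 - beta1)
            - tf.math.digamma(alpha1) * (alpha2 - alpha1)
            + tf.math.digamma(alpha_beta1) * (alpha_beta2 - alpha_beta1))

p1, p2 = beta_params(2.0, 5.0), beta_params(3.0, 3.0)
print(float(kl_beta(p1, p2)))  # KL(Beta(2, 5) || Beta(3, 3)), a positive scalar

Note that KL(p1, p2) and KL(p2, p1) generally differ, and the result is zero exactly when both parameter pairs coincide.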