def gaussian_kl_divergence_keepbatch(self, mean, ln_var):
    """Return the KL divergence KL(N(mean, exp(ln_var)) || N(0, I)), reduced over axis 1.

    Unlike a fully-reduced KL, the sum runs only over ``axis=1``, so the
    leading (batch) axis is kept.

    Args:
        mean: Mean of the diagonal Gaussian.
        ln_var: Element-wise log-variance of the same Gaussian; must match
            ``mean`` in shape.

    Returns:
        ``0.5 * sum(mean**2 + exp(ln_var) - ln_var - 1, axis=1)``.

    NOTE(review): assumes ``F`` is ``chainer.functions`` and that the inputs
    are 2-D ``(batch, dim)`` — with more dimensions only axis 1 is reduced.
    TODO confirm against callers.
    """
    variance = F.exp(ln_var)
    # Element-wise closed-form KL term against a standard normal prior.
    per_element = mean ** 2 + variance - ln_var - 1
    return 0.5 * F.sum(per_element, axis=1)