def gaussian_kl_divergence_keepbatch(self, mean, ln_var):
    """Return the KL divergence to a standard normal, reduced over axis 1 only.

    Evaluates ``0.5 * sum_axis1(mean**2 + exp(ln_var) - ln_var - 1)``. This is
    the closed form of KL(N(mean, diag(exp(ln_var))) || N(0, I)). Unlike a
    fully-reduced KL, the sum is taken only along axis 1, so any other axes
    are kept.

    NOTE(review): presumably axis 0 is the batch axis and the result is one
    KL value per example ("keepbatch"). For inputs of rank > 2, axes beyond
    1 are NOT summed. Confirm against callers.

    Args:
        mean: Mean of the Gaussian. Presumably a Variable/array of shape
            (batch, dim); TODO confirm.
        ln_var: Log-variance, ``log(sigma**2)``, with the same shape as
            ``mean``.

    Returns:
        ``mean`` and ``ln_var`` summed along axis 1 and halved. ``self`` is
        not used.
    """
    # Recover the variance from its log parameterisation.
    variance = F.exp(ln_var)
    # Per-element KL integrand; the term order matches the original
    # expression so floating-point results are identical.
    integrand = mean * mean + variance - ln_var - 1
    summed = F.sum(integrand, axis=1)
    return summed * 0.5