def log_pz(self, z):
    """Return the standard-normal log-density of ``z``, summed over axis 1.

    Each element contributes ``-0.5 * log(2 * pi) - 0.5 * z ** 2``, which is
    the log of the unit Gaussian N(0, 1) evaluated at that element.
    The per-element terms are then reduced with ``F.sum(..., axis=1)``.

    Args:
        z: Array/variable of latent samples. Presumably shaped
            (batch, latent_dim) with axis 1 being the latent dimension —
            TODO confirm against callers. ``F.sum(axis=1)`` needs at least
            2 dims.

    Returns:
        The result of ``F.sum`` over axis 1 of the element-wise log-density.
    """
    # Normalisation constant of N(0, 1): log(1 / sqrt(2 * pi)).
    norm_const = -0.5 * math.log(2.0 * math.pi)
    elementwise = norm_const - 0.5 * z ** 2
    return F.sum(elementwise, axis=1)
# compute lower bound using gumbel-softmax