def gauss_log_prob(means, logstds, x):
    """Return the log-density of ``x`` under a diagonal Gaussian.

    For every element the log-density is::

        -(x - mean)**2 / (2 * var) - 0.5 * LOG2PI - logstd

    where ``var = exp(2 * logstd)``. The per-element values are then summed
    over dimension 1.

    Args:
        means: Tensor of Gaussian means.
        logstds: Tensor of log standard deviations, broadcastable against
            ``means``.
        x: Tensor of points at which to evaluate the density, broadcastable
            against ``means``.

    Returns:
        Tensor of per-element log-densities summed over ``dim=1``.

    Note:
        ``LOG2PI`` is a module-level constant; presumably ``log(2 * pi)`` --
        confirm against its definition. The inputs are presumably laid out
        as ``(batch, dim)`` -- confirm against callers.
    """
    var = th.exp(2 * logstds)
    # Only the squared error is divided by 2 * var; the normalisation terms
    # (0.5 * log(2*pi) and log(std)) are subtracted afterwards. They must not
    # be folded into the denominator.
    squared_error = (x - means) ** 2
    gp = -squared_error / (2 * var) - 0.5 * LOG2PI - logstds
    return th.sum(gp, dim=1)