def batch_log_pdf(self, x):
    """
    Diagonal Normal log-likelihood.
    Ref: :py:meth:`pyro.distributions.distribution.Distribution.batch_log_pdf`
    """
    # expand the parameters to match the shape of the input
    mu = self.mu.expand(self.shape(x))
    sigma = self.sigma.expand(self.shape(x))
    # elementwise log-density of a univariate Normal:
    # -log(sigma) - 0.5 * log(2 * pi) - 0.5 * ((x - mu) / sigma)^2
    log_pxs = -1 * (torch.log(sigma) + 0.5 * np.log(2.0 * np.pi) + 0.5 * torch.pow((x - mu) / sigma, 2))
    # XXX this allows the user to mask out certain parts of the score, for example
    # when the data is a ragged tensor; also useful for KL annealing. this entire logic
    # will likely be done in a better/cleaner way in the future
    if self.log_pdf_mask is not None:
        log_pxs = log_pxs * self.log_pdf_mask
    # sum over the rightmost (event) dimension, then reshape to batch_shape + (1,)
    batch_log_pdf = torch.sum(log_pxs, -1)
    batch_log_pdf_shape = self.batch_shape(x) + (1,)
    return batch_log_pdf.contiguous().view(batch_log_pdf_shape)
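
# ---------------------------------------------------------------------------
# Minimal standalone sketch (not part of the class above): the same diagonal
# Normal log-density, computed elementwise and summed over the rightmost
# dimension. It assumes only `torch` and `numpy`; the function name
# `diag_normal_batch_log_pdf` is illustrative, not a Pyro API.
# ---------------------------------------------------------------------------
import numpy as np
import torch


def diag_normal_batch_log_pdf(x, mu, sigma):
    # elementwise: -log(sigma) - 0.5 * log(2 * pi) - 0.5 * ((x - mu) / sigma)^2
    log_pxs = -(torch.log(sigma)
                + 0.5 * np.log(2.0 * np.pi)
                + 0.5 * torch.pow((x - mu) / sigma, 2))
    # sum over the event dimension, keeping a trailing singleton dimension
    # to mirror the batch_shape + (1,) reshape above
    return torch.sum(log_pxs, -1, keepdim=True)


# usage: a batch of 4 three-dimensional points under a standard Normal
x = torch.randn(4, 3)
mu = torch.zeros(4, 3)
sigma = torch.ones(4, 3)
print(diag_normal_batch_log_pdf(x, mu, sigma))  # shape (4, 1)
# should agree with torch.distributions.Normal(mu, sigma).log_prob(x).sum(-1, keepdim=True)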