def log_marginal(self, y, h, py, q):
    r"""Compute an importance-sampling estimate of the log marginal likelihood.

    For each entry the estimate is

        \log \frac{1}{N} \sum_{n=1}^{N} \frac{p(y | h_n) p(h_n)}{q(h_n | y)}

    i.e. ``\log \sum p / q - \log N``, where the N samples lie along axis 0
    of the log-weights. It is evaluated with the max-shift (log-sum-exp)
    trick for numerical stability, and the per-entry estimates are then
    averaged into a single scalar.

    Args:
        y: T.tensor, target values.
        h: T.tensor, latent samples. Axis 0 of the resulting log-weights is
            reduced as the sample axis (presumably h is laid out as
            (n_samples, batch, ...) -- TODO confirm against callers).
        py: passed with ``y`` to ``self.conditional.neg_log_prob`` to
            evaluate log p(y | h) (presumably the conditional's distribution
            parameters -- verify).
        q: passed with ``h`` to ``self.posterior.neg_log_prob`` to evaluate
            log q(h | y) (presumably the approximate posterior's parameters
            -- verify).

    Returns:
        Scalar tensor: the log-mean-exp of the log importance weights over
        axis 0, averaged over all remaining entries.

    Raises:
        ValueError: if log p(y | h), log p(h) and log q(h | y) do not have
            the same number of dimensions.
    """
    log_py_h = -self.conditional.neg_log_prob(y, py)
    log_ph = -self.prior.neg_log_prob(h)
    log_qh = -self.posterior.neg_log_prob(h, q)
    # Explicit check rather than ``assert`` so it survives ``python -O``.
    if not (log_py_h.ndim == log_ph.ndim == log_qh.ndim):
        raise ValueError(
            'log_marginal: mismatched ndim for log p(y|h)={}, log p(h)={}, '
            'log q(h|y)={}'.format(log_py_h.ndim, log_ph.ndim, log_qh.ndim))
    # Log importance weights: log p(y | h) + log p(h) - log q(h | y).
    log_w = log_py_h + log_ph - log_qh
    # Subtract the per-entry maximum over axis 0 before exponentiating so
    # exp() cannot overflow; the shift is added back after the log.
    log_w_max = T.max(log_w, axis=0, keepdims=True)
    w = T.exp(log_w - log_w_max)
    log_mean_w = T.log(w.mean(axis=0, keepdims=True)) + log_w_max
    # Average the per-entry estimates into a scalar.
    return log_mean_w.mean()
评论列表
文章目录