def step_infer(self, r, q, y, *params):
    '''Run one refinement step of the IRVI inference scan.

    Draws binary samples from the current approximate posterior, scores
    each sample by its log importance ratio, and moves ``q`` part of the
    way toward the importance-weighted average of the samples.

    Args:
        r: theano randomstream variable. Compared element-wise against
            ``q`` to produce binary samples (presumably uniform draws with
            a leading sample axis -- confirm against the scan caller).
        q: T.tensor. Current approximate posterior parameters.
        y: T.tensor. Data sample.
        params: list of shared variables, forwarded to the model.
    Returns:
        q: T.tensor. Refined approximate posterior parameters.
        cost: T.scalar float. Negative lower bound evaluated at the
            parameters passed in (i.e. before this step's update).
    '''
    mdl = self.model
    prior_args = mdl.get_prior_params(*params)

    # Prepend a sample axis so q and y line up with r and the samples.
    q_b = q[None, :, :]
    y_b = y[None, :, :]

    # Binary samples: 1 wherever the random draw does not exceed q.
    h_samples = (r <= q_b).astype(floatX)

    y_params = mdl.p_y_given_h(h_samples, *params)
    log_lik = -mdl.conditional.neg_log_prob(y_b, y_params)
    log_prior = -mdl.prior.step_neg_log_prob(h_samples, *prior_args)
    log_post = -mdl.posterior.neg_log_prob(h_samples, q_b)

    # Per-sample log importance weight: log p(y|h) + log p(h) - log q(h).
    log_w = log_lik + log_prior - log_post
    # NOTE(review): get_w_tilde is defined elsewhere; presumably it
    # self-normalizes the weights over the sample axis.
    norm_w = get_w_tilde(log_w)
    cost = -log_w.mean()

    # Weighted average of the samples, reducing the sample axis (axis 0).
    q_est = (norm_w[:, :, None] * h_samples).sum(axis=0)

    # Convex blend of the new estimate with the previous parameters.
    rate = self.inference_rate
    q_next = rate * q_est + (1 - rate) * q
    return q_next, cost
评论列表
文章目录