def compute_log_lik_exp(self, m, v, y, gh_degree=None):
    """Compute the expected probit log-likelihood and its gradients.

    Approximates ``sum E_{f ~ N(m, v)}[log Phi(y * f)]`` over all entries
    with Gauss-Hermite quadrature. It also returns the exact gradients of
    that quadrature approximation with respect to ``m`` and ``v``.

    Args:
        m: Array of means. A 2-D array is treated as one set of marginals.
            A 3-D array is treated as a stack of such sets along axis 0, and
            the result is averaged over that axis.
        v: Array of variances with the same shape as ``m``. Must be positive,
            because the ``v`` gradient divides by ``sqrt(v)``.
        y: Labels, broadcastable against ``m``. Presumably +/-1; verify
            against callers.
        gh_degree: Number of Gauss-Hermite points. Defaults to the
            module-level ``GH_DEGREE``.

    Returns:
        Tuple ``(loglik, dm, dv)``. ``loglik`` is the scalar expected
        log-likelihood. ``dm`` and ``dv`` are its gradients with respect to
        ``m`` and ``v``; each has the broadcast shape of ``m``, ``v`` and
        ``y``.

    Raises:
        ValueError: If ``m`` is not 2-D or 3-D.
    """
    m = np.asarray(m)
    v = np.asarray(v)
    y = np.asarray(y)
    if m.ndim not in (2, 3):
        raise ValueError(
            "m must be 2-D or 3-D, got an array of shape %s" % (m.shape,))
    if gh_degree is None:
        gh_degree = GH_DEGREE

    # For 3-D input, per-sample contributions along axis 0 are averaged.
    scale = 1.0 if m.ndim == 2 else 1.0 / m.shape[0]

    # Give the 1-D nodes/weights a leading quadrature axis that broadcasts
    # in front of every axis of m. The 1/sqrt(pi) factor turns the
    # physicists' Hermite rule into an expectation under a standard normal.
    quad_shape = (-1,) + (1,) * m.ndim
    gh_x, gh_w = self._gh_points(gh_degree)
    gh_x = np.reshape(gh_x, quad_shape)
    gh_w = np.reshape(gh_w, quad_shape) / np.sqrt(np.pi)

    # Change of variables: f = x * sqrt(2 v) + m.
    ts = gh_x * np.sqrt(2 * v) + m
    z = ts * y
    logcdfs = norm.logcdf(z)
    loglik = scale * np.sum(gh_w * logcdfs)

    # d/dz log Phi(z) = phi(z) / Phi(z). The ratio is evaluated in log space
    # because pdf / cdf becomes 0/0 = nan once Phi(z) underflows (z << 0).
    mills = np.exp(norm.logpdf(z) - logcdfs)
    grad_ts = y * gh_w * mills

    # Chain rule: dts/dm = 1 and dts/dv = x / sqrt(2 v).
    dts_dv = 0.5 * gh_x * np.sqrt(2 / v)
    dm = scale * np.sum(grad_ts, axis=0)
    dv = scale * np.sum(grad_ts * dts_dv, axis=0)
    return loglik, dm, dv
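# 评论列表 (comment list) -- web-page residue, not part of the code
# 文章目录 (table of contents) -- web-page residue, not part of the code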