def _causal_effect(
hparams, mu1, mu1s_, mu2, mu2s_, tau_cmmn, obs1, obs2, Normal, floatX):
u"""Distribution of observations.
"""
if hparams['causality'] == [1, 2]:
# ---- Model 1: x1 -> x2 ----
x1s = obs1(mu=mu1 + mu1s_)
b = Normal('b', mu=np.float32(0.),
tau=np.float32(1 / tau_cmmn[1]), dtype=floatX)
x2s = obs2(mu=mu2 + mu2s_ + b * (x1s - mu1 - mu1s_)) \
if hparams['subtract_mu_reg'] else \
obs2(mu=mu2 + mu2s_ + b * x1s)
else:
# ---- Model 2: x2 -> x1 ----
x2s = obs2(mu=mu2 + mu2s_)
b = Normal('b', mu=np.float32(0.),
tau=np.float32(1 / tau_cmmn[0]), dtype=floatX)
x1s = obs1(mu=mu1 + mu1s_ + b * (x2s - mu2 - mu2s_)) \
if hparams['subtract_mu_reg'] else \
obs1(mu=mu1 + mu1s_ + b * x2s)
return x1s, x2s, b
评论列表
文章目录