def testTraversesControlInputs(self):
    """Stochastic tensors reaching `out` through a `cond` are all reported.

    Graph under test:
      * ``normal_root`` (Normal) is scaled into the logits of ``coin``
        (Bernoulli), so it reaches ``out`` only through ``coin``.
      * ``coin`` is used only as the ``control_flow_ops.cond`` predicate.
      * ``branch_normal`` (Normal) feeds the true branch of that ``cond``.

    ``sg._stochastic_dependencies_map([out])`` must map every one of the
    three stochastic tensors to exactly ``{out}``.
    """
    normal_root = st.StochasticTensor(distributions.Normal(loc=0., scale=1.))
    bernoulli_logits = normal_root.value() * 3.
    coin = st.StochasticTensor(distributions.Bernoulli(logits=bernoulli_logits))
    branch_normal = st.StochasticTensor(
        distributions.Normal(loc=0., scale=1.))
    branch_input = branch_normal.value()
    fours = array_ops.ones((2, 2)) * 4.
    threes = array_ops.ones((2, 2)) * 3.

    # `coin` reaches `out` only via the cond predicate (it selects a branch).
    predicate = math_ops.cast(coin, dtypes.bool)

    def _true_fn():
        return math_ops.add(branch_input, fours)

    def _false_fn():
        return math_ops.square(threes)

    out = control_flow_ops.cond(predicate, _true_fn, _false_fn)
    out += 5.

    dep_map = sg._stochastic_dependencies_map([out])
    expected = {out}
    for stochastic_tensor in (normal_root, coin, branch_normal):
        self.assertEqual(dep_map[stochastic_tensor], expected)
stochastic_graph_test.py 文件源码
python
阅读 18
收藏 0
点赞 0
评论 0
评论列表
文章目录