stochastic_graph_test.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def testTraversesControlInputs(self):
    dt1 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.))
    logits = dt1.value() * 3.
    dt2 = st.StochasticTensor(distributions.Bernoulli(logits=logits))
    dt3 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.))
    x = dt3.value()
    y = array_ops.ones((2, 2)) * 4.
    z = array_ops.ones((2, 2)) * 3.
    out = control_flow_ops.cond(
        math_ops.cast(dt2, dtypes.bool), lambda: math_ops.add(x, y),
        lambda: math_ops.square(z))
    out += 5.
    dep_map = sg._stochastic_dependencies_map([out])
    self.assertEqual(dep_map[dt1], set([out]))
    self.assertEqual(dep_map[dt2], set([out]))
    self.assertEqual(dep_map[dt3], set([out]))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号