stochastic_graph_test.py source code

python

Project: DeepLearning_VirtualReality_BigData_Project  Author: rashmitripathi
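# Method of a test.TestCase subclass in stochastic_graph_test.py. In the full
# file, st (stochastic_tensor), sg (stochastic_graph_impl), constant_op,
# array_ops and math_ops are module imports/aliases, and NormalNotParam is a
# helper Normal distribution that reports itself as not reparameterized.
def testExplicitStochasticTensors(self):
  with self.test_session() as sess:
    mu = constant_op.constant([0.0, 0.1, 0.2])
    sigma = constant_op.constant([1.1, 1.2, 1.3])
    with st.value_type(st.SampleValue()):
      # Two independent, non-reparameterized samples from N(mu, sigma).
      dt1 = st.StochasticTensor(NormalNotParam(loc=mu, scale=sigma))
      dt2 = st.StochasticTensor(NormalNotParam(loc=mu, scale=sigma))
      loss = math_ops.square(array_ops.identity(dt1)) + 10. + dt2

      # Surrogate loss over all stochastic tensors, and over each one alone.
      sl_all = sg.surrogate_loss([loss])
      sl_dt1 = sg.surrogate_loss([loss], stochastic_tensors=[dt1])
      sl_dt2 = sg.surrogate_loss([loss], stochastic_tensors=[dt2])

      # Expected score-function terms: log p(sample) * loss.
      dt1_term = dt1.distribution.log_prob(dt1) * loss
      dt2_term = dt2.distribution.log_prob(dt2) * loss

      # Each surrogate loss should equal the original loss plus the
      # score-function terms of exactly the selected stochastic tensors.
      self.assertAllClose(*sess.run(
          [sl_all, sum([loss, dt1_term, dt2_term])]))
      self.assertAllClose(*sess.run([sl_dt1, sum([loss, dt1_term])]))
      self.assertAllClose(*sess.run([sl_dt2, sum([loss, dt2_term])]))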
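The `dt1_term` and `dt2_term` values are the score-function (REINFORCE) terms. The following minimal NumPy sketch illustrates the gradient estimator they encode. It is not the TensorFlow implementation. It assumes, as the helper's name suggests, that `NormalNotParam` samples carry no pathwise gradient. The function names `normal_log_prob_grad_mu` and `score_function_grad_mu` are hypothetical.

Differentiating `log_prob(x) * loss` with respect to `mu`, with `loss` held fixed, gives `(x - mu) / sigma**2 * loss`. Averaging that quantity over samples should approach the exact gradient of the expected loss.

```python
import numpy as np


def normal_log_prob_grad_mu(x, mu, sigma):
    """d/dmu of log N(x; mu, sigma), i.e. the score with respect to the mean."""
    return (x - mu) / sigma**2


def score_function_grad_mu(mu, sigma, num_samples=500_000, seed=0):
    """Monte Carlo estimate of d/dmu E[loss] via the score-function trick.

    Mirrors the test's setup (hypothetical re-creation, not TF code):
    dt1, dt2 ~ N(mu, sigma) independently and loss = dt1**2 + 10 + dt2.
    """
    rng = np.random.default_rng(seed)
    shape = (num_samples,) + np.shape(mu)
    x1 = rng.normal(mu, sigma, size=shape)
    x2 = rng.normal(mu, sigma, size=shape)
    loss = x1**2 + 10.0 + x2
    # Both samples depend on mu, so their score terms add up; loss acts as a
    # constant weight (the role stop_gradient plays in a surrogate loss).
    score = normal_log_prob_grad_mu(x1, mu, sigma) + normal_log_prob_grad_mu(x2, mu, sigma)
    return np.mean(score * loss, axis=0)


mu = np.array([0.0, 0.1, 0.2])
sigma = np.array([1.1, 1.2, 1.3])

estimate = score_function_grad_mu(mu, sigma)
# Closed form: E[loss] = mu**2 + sigma**2 + 10 + mu, so d/dmu E[loss] = 2*mu + 1.
exact = 2 * mu + 1
print(estimate, exact)
```

With the test's `mu` and `sigma`, the exact gradient is `[1.0, 1.2, 1.4]`, and the estimate should land close to it. The constant `10.` in the loss adds variance but no bias. This is why practical score-function estimators usually subtract a baseline.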