def testExplicitStochasticTensors(self):
    """Checks surrogate_loss when restricted to explicit StochasticTensors.

    Two StochasticTensors (both `NormalNotParam(loc, scale)`, built under a
    `SampleValue` value type) feed a single `loss`. The surrogate loss is
    built three ways: over all stochastic tensors, over only the first, and
    over only the second. Each must numerically match `loss` plus the
    `distribution.log_prob(value) * loss` term of exactly the stochastic
    tensors it was restricted to.
    """
    with self.test_session() as sess:
        loc = constant_op.constant([0.0, 0.1, 0.2])
        scale = constant_op.constant([1.1, 1.2, 1.3])
        with st.value_type(st.SampleValue()):
            first = st.StochasticTensor(NormalNotParam(loc=loc, scale=scale))
            second = st.StochasticTensor(
                NormalNotParam(loc=loc, scale=scale))
            # `first` reaches the loss through an identity op; `second` is
            # added directly.
            loss = math_ops.square(array_ops.identity(first)) + 10. + second

            surrogate_all = sg.surrogate_loss([loss])
            surrogate_first = sg.surrogate_loss(
                [loss], stochastic_tensors=[first])
            surrogate_second = sg.surrogate_loss(
                [loss], stochastic_tensors=[second])

            first_term = first.distribution.log_prob(first) * loss
            second_term = second.distribution.log_prob(second) * loss

            # Each surrogate is paired with the pieces whose sum it should
            # equal.
            expectations = [
                (surrogate_all, [loss, first_term, second_term]),
                (surrogate_first, [loss, first_term]),
                (surrogate_second, [loss, second_term]),
            ]
            for surrogate, pieces in expectations:
                # Fetch actual and expected in one run so both are computed
                # in the same step (presumably from the same samples).
                self.assertAllClose(*sess.run([surrogate, sum(pieces)]))
stochastic_graph_test.py 文件源码
python
阅读 20
收藏 0
点赞 0
评论 0
评论列表
文章目录