def testStochasticVariables(self):
    """Check variables created through the stochastic-variable custom getter.

    Requests a variable named "sv" of shape (10, 20) inside the
    "stochastic_variables" scope, whose custom getter is built by
    `sv.make_stochastic_variable_getter` with
    `dist_cls=dist.NormalWithSoftplusScale`, and then asserts that:

    * the returned object is a `st.StochasticTensor` whose `distribution`
      is a `dist.NormalWithSoftplusScale`;
    * the global variables are exactly "stochastic_variables/sv_loc" and
      "stochastic_variables/sv_scale", and the trainable variables are the
      same set;
    * converting the object to a tensor gives the requested static shape,
      and evaluating it in a session gives an array of that shape.
    """
    shape = (10, 20)
    with variable_scope.variable_scope(
            "stochastic_variables",
            custom_getter=sv.make_stochastic_variable_getter(
                dist_cls=dist.NormalWithSoftplusScale)):
        v = variable_scope.get_variable("sv", shape)

    # assertIsInstance reports the offending type when the check fails.
    self.assertIsInstance(v, st.StochasticTensor)
    self.assertIsInstance(v.distribution, dist.NormalWithSoftplusScale)

    # The loop variable must not be called `v`: on Python 2 a list
    # comprehension leaks its variable into the enclosing scope, which would
    # silently replace the stochastic tensor checked below. A set
    # comprehension with a distinct name avoids that on every version.
    global_vars = variables.global_variables()
    self.assertEqual(
        {"stochastic_variables/sv_loc", "stochastic_variables/sv_scale"},
        {var.op.name for var in global_vars})
    self.assertEqual(set(variables.trainable_variables()), set(global_vars))

    # Static shape of the converted tensor, then the shape of its value.
    tensor = ops.convert_to_tensor(v)
    self.assertEqual(list(shape), tensor.get_shape().as_list())
    with self.test_session() as sess:
        sess.run(variables.global_variables_initializer())
        self.assertEqual(shape, sess.run(tensor).shape)
stochastic_variables_test.py 文件源码
python
阅读 20
收藏 0
点赞 0
评论 0
评论列表
文章目录