test_stochastic.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:zhusuan 作者: thu-ml 项目源码 文件源码
def test_Normal(self):
        """Check graph dependencies and sample shape of ``Normal`` nodes.

        Two nodes are built inside a ``BayesianNet`` context: ``a`` is
        parameterized with ``logstd`` and ``b`` with ``std = exp(logstd)``.
        For each node the test verifies that:

        * the ops of ``mean``, ``logstd`` and ``n_samples`` are among the
          backward ops of the node's sample tensor;
        * the ops of ``mean``, ``logstd`` and ``group_ndims`` are among the
          backward ops of its ``log_prob`` output;
        * the node's shape, with the first axis dropped, equals the shape
          of ``mean``.
        """
        with BayesianNet():
            mean = tf.zeros([2, 3])
            logstd = tf.zeros([2, 3])
            std = tf.exp(logstd)
            n_samples = tf.placeholder(tf.int32, shape=[])
            group_ndims = tf.placeholder(tf.int32, shape=[])
            a = Normal('a', mean, logstd=logstd, n_samples=n_samples,
                       group_ndims=group_ndims)
            b = Normal('b', mean, std=std, n_samples=n_samples,
                       group_ndims=group_ndims)

        for st in [a, b]:
            sample_ops = set(get_backward_ops(st.tensor))
            for i in [mean, logstd, n_samples]:
                self.assertIn(i.op, sample_ops)
            log_p = st.log_prob(np.ones([2, 3]))
            log_p_ops = set(get_backward_ops(log_p))
            for i in [mean, logstd, group_ndims]:
                self.assertIn(i.op, log_p_ops)
            # assertEqual, not assertTrue: assertTrue(x, y) treats y as the
            # failure message and never compares the two shapes. Check the
            # node under iteration (st) so both a and b are covered.
            # NOTE(review): the dropped first axis is presumably the
            # n_samples axis -- confirm against Normal's implementation.
            self.assertEqual(st.get_shape()[1:], mean.get_shape())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号