test_utils.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:zhusuan 作者: thu-ml 项目源码 文件源码
def testGetBackwardOpsSplit(self):
        # a -> b -> c
        #       \-> d
        a = tf.placeholder(tf.float32)
        b = tf.exp(a)
        c = tf.log(b)
        d = tf.negative(b)
        self.assertEqual(get_backward_ops([d]), [a.op, b.op, d.op])
        self.assertEqual(get_backward_ops([c]), [a.op, b.op, c.op])
        self.assertEqual(
            get_backward_ops([c, d]), [a.op, b.op, c.op, d.op])
        self.assertEqual(get_backward_ops([b, d]), [a.op, b.op, d.op])
        self.assertEqual(get_backward_ops([a, d]), [a.op, b.op, d.op])

        self.assertEqual(
            get_backward_ops([c, d], treat_as_inputs=[b]), [c.op, d.op])
        self.assertEqual(
            get_backward_ops([c], treat_as_inputs=[d]), [a.op, b.op, c.op])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号