test_utils.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:zhusuan 作者: thu-ml 项目源码 文件源码
def testGetBackwardOpsChain(self):
        # a -> b -> c
        a = tf.placeholder(tf.float32)
        b = tf.sqrt(a)
        c = tf.square(b)
        for n in range(4):
            for seed_tensors in permutations([a, b, c], n):
                if c in seed_tensors:
                    truth = [a.op, b.op, c.op]
                elif b in seed_tensors:
                    truth = [a.op, b.op]
                elif a in seed_tensors:
                    truth = [a.op]
                else:
                    truth = []
                self.assertEqual(get_backward_ops(seed_tensors), truth)

        self.assertEqual(get_backward_ops([c], treat_as_inputs=[b]), [c.op])
        self.assertEqual(
            get_backward_ops([b, c], treat_as_inputs=[b]), [c.op])
        self.assertEqual(
            get_backward_ops([a, c], treat_as_inputs=[b]), [a.op, c.op])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号