def testGetBackwardOpsChain(self):
    """Check get_backward_ops on the three-tensor chain a -> b -> c.

    Every ordered selection of 0 to 3 tensors from (a, b, c) is used as the
    seed set. The expected result is the op of the deepest seeded tensor plus
    the ops of all tensors before it in the chain, in chain order. An empty
    seed set yields an empty list.

    The final checks pass ``treat_as_inputs=[b]``. The expected results show
    that ``b.op`` is then never returned, even when ``b`` is itself a seed, and
    that ``a.op`` is returned only when ``a`` is seeded directly.
    """
    # Build the linear graph a -> b -> c.
    a = tf.placeholder(tf.float32)
    b = tf.sqrt(a)
    c = tf.square(b)
    chain = [a, b, c]

    for size in range(len(chain) + 1):
        for seeds in permutations(chain, size):
            # Chain positions that appear among the seeds, in ascending order,
            # so the last entry is the deepest seeded tensor.
            hit_positions = [pos for pos, tensor in enumerate(chain)
                             if tensor in seeds]
            if hit_positions:
                expected = [tensor.op
                            for tensor in chain[:hit_positions[-1] + 1]]
            else:
                expected = []
            self.assertEqual(get_backward_ops(seeds), expected)

    # (seed tensors, expected ops) when b is treated as a graph input.
    cut_at_b_cases = [
        ([c], [c.op]),
        ([b, c], [c.op]),
        ([a, c], [a.op, c.op]),
    ]
    for seeds, expected in cut_at_b_cases:
        self.assertEqual(get_backward_ops(seeds, treat_as_inputs=[b]), expected)
# Comment list
# Table of contents