def testGetBackwardOpsSplit(self):
    """get_backward_ops on a graph that forks after the Exp op."""
    # Graph under test (b feeds both c and d):
    #   a -> b -> c
    #         \-> d
    # Ops are created in the same order as before (Placeholder, Exp, Log, Neg).
    a = tf.placeholder(tf.float32)
    b = tf.exp(a)
    c = tf.log(b)
    d = tf.negative(b)

    # Each entry: (tensors passed in, extra keyword arguments, expected ops).
    cases = [
        ([d], {}, [a.op, b.op, d.op]),
        ([c], {}, [a.op, b.op, c.op]),
        ([c, d], {}, [a.op, b.op, c.op, d.op]),
        ([b, d], {}, [a.op, b.op, d.op]),
        ([a, d], {}, [a.op, b.op, d.op]),
        # With b treated as an input, the test expects neither b.op nor its
        # ancestor a.op in the result -- only the ops after the fork.
        ([c, d], {"treat_as_inputs": [b]}, [c.op, d.op]),
        # d is not an ancestor of c, so the expected result is the same as
        # for the plain [c] case above.
        ([c], {"treat_as_inputs": [d]}, [a.op, b.op, c.op]),
    ]
    for targets, extra_kwargs, expected_ops in cases:
        self.assertEqual(get_backward_ops(targets, **extra_kwargs), expected_ops)
评论列表
文章目录