def testTwoOps(self):
    """Checks that two stacked scale_gradient ops each apply their own scale.

    A placeholder is squared and then passed through ``snt.scale_gradient``
    twice, each time with a scale of 0.1. The forward value must still equal
    x**2. The gradient must equal the analytic derivative 2*x multiplied by
    both scale factors. Implementations with inappropriate global
    registration of gradients will fail this test.
    """
    scale = 0.1
    x_value = 3.0

    inp = tf.placeholder(tf.float32, [1])
    out = inp * inp
    # Wrap the same tensor in two separate scale_gradient instances; each one
    # is expected to contribute its own factor on the backward pass.
    for _ in range(2):
        out = snt.scale_gradient(out, scale)
    grad = tf.gradients([out], [inp])[0]

    with self.test_session() as sess:
        grad_val, out_val = sess.run([grad, out], feed_dict={inp: [x_value]})
        expected_grad = 2 * scale**2 * x_value
        expected_out = x_value ** 2
        self.assertAlmostEqual(grad_val[0], expected_grad, places=6)
        self.assertAlmostEqual(out_val[0], expected_out, places=6)
评论列表
文章目录