def testOpScale(self, x_, scale):
  """Checks op selection, forward value and gradient of `snt.scale_gradient`.

  Builds y = x * x, wraps it with `snt.scale_gradient(y, scale)` and verifies:
    * the op used for the wrapper: "StopGradient" when `scale == 0.0`,
      "Identity" when `scale == 1.0`, "ScaleGradient_float32" otherwise;
    * the forward value is unchanged (y == x_ ** 2) for every scale,
      including the stopped-gradient case;
    * the gradient dy/dx equals `2 * scale * x_`, or is None when
      `scale == 0.0`.

  Args:
    x_: Scalar value fed to the float32 placeholder `x` (shape [1]).
    scale: Gradient scale passed to `snt.scale_gradient`.
  """
  x = tf.placeholder(tf.float32, [1])
  y = x * x
  y = snt.scale_gradient(y, scale)
  dydx = tf.gradients([y], [x])[0]

  if scale == 0.0:
    expected_op_type = "StopGradient"
  elif scale == 1.0:
    expected_op_type = "Identity"
  else:
    expected_op_type = "ScaleGradient_float32"
  self.assertEqual(y.op.type, expected_op_type)

  with self.test_session() as sess:
    if scale == 0.0:
      # With the gradient stopped, tf.gradients finds no path back to `x`
      # and returns None, so only the forward value can be fetched.
      self.assertIsNone(dydx)
      y_ = sess.run(y, feed_dict={x: [x_]})
    else:
      dydx_, y_ = sess.run([dydx, y], feed_dict={x: [x_]})
      self.assertAlmostEqual(dydx_[0], 2 * scale * x_, places=6)
    # Scaling only affects the backward pass; the forward value must be x**2
    # for every scale (previously unchecked when scale == 0.0).
    self.assertAlmostEqual(y_[0], x_ ** 2, places=6)
评论列表
文章目录