scale_gradient_test.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:sonnet 作者: deepmind 项目源码 文件源码
def testOpScale(self, x_, scale):
    x = tf.placeholder(tf.float32, [1])
    y = x * x
    y = snt.scale_gradient(y, scale)
    dydx = tf.gradients([y], [x])[0]

    if scale == 0.0:
      self.assertEqual(y.op.type, "StopGradient")
      self.assertIs(dydx, None)
    else:
      if scale == 1.0:
        self.assertEqual(y.op.type, "Identity")
      else:
        self.assertEqual(y.op.type, "ScaleGradient_float32")

      with self.test_session() as sess:
        dydx_, y_ = sess.run([dydx, y], feed_dict={x: [x_]})

        self.assertAlmostEqual(dydx_[0], 2 * scale * x_, places=6)
        self.assertAlmostEqual(y_[0], x_ ** 2, places=6)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号