stop_gradient_test.py source code

Language: Python

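Project: sonnet  Author: deepmind

This test checks that variables created under the `snt.custom_getters.stop_gradient` custom getter receive no gradients.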
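import sonnet as snt
import tensorflow as tf  # TF1.x graph-mode API (tf.placeholder, tf.gradients)


class StopGradientTest(tf.test.TestCase):

  def testUsage(self):
    # Every variable created under this scope goes through the stop_gradient
    # custom getter.
    with tf.variable_scope("", custom_getter=snt.custom_getters.stop_gradient):
      lin1 = snt.Linear(10, name="linear1")

    x = tf.placeholder(tf.float32, [10, 10])
    # The module is connected outside the `with` block, but its variables
    # still go through the getter, as the assertions below check.
    y = lin1(x)

    variables = tf.trainable_variables()
    variable_names = [v.name for v in variables]

    # The getter does not change which variables exist: only w and b.
    self.assertEqual(2, len(variables))
    self.assertIn("linear1/w:0", variable_names)
    self.assertIn("linear1/b:0", variable_names)

    grads = tf.gradients(y, variables)
    names_to_grads = {var.name: grad for var, grad in zip(variables, grads)}

    # No gradient flows back to either variable.
    self.assertIsNone(names_to_grads["linear1/w:0"])
    self.assertIsNone(names_to_grads["linear1/b:0"])


if __name__ == "__main__":
  tf.test.main()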
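The test checks that a module built under the `stop_gradient` custom getter still creates its variables `w` and `b`, but `tf.gradients` returns `None` for both. A custom getter, as used here, is a function that receives the underlying getter plus its arguments and may transform what that getter returns; presumably Sonnet's getter wraps the result in `tf.stop_gradient`. The following is a toy, pure-Python sketch of that pattern, not Sonnet or TensorFlow code: `Node`, `gradients`, `VariableStore`, `stop_gradient_getter` and `linear` are hypothetical names, and the autodiff is a minimal scalar stand-in.

```python
# Toy, pure-Python sketch (no TensorFlow): all names here are hypothetical.


class Node:
    """A scalar in a tiny reverse-mode autodiff graph."""

    def __init__(self, value, parents=()):
        self.value = value
        # Each parent is stored with the local derivative d(self)/d(parent).
        self.parents = parents

    def __add__(self, other):
        return Node(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Node(self.value * other.value,
                    ((self, other.value), (other, self.value)))


def stop_gradient(node):
    """Same forward value, but no parents, so backprop cannot reach `node`."""
    return Node(node.value)


def gradients(y, xs):
    """d(y)/d(x) for each x in xs, or None if x is unreachable from y."""
    grads = {}
    stack = [(y, 1.0)]
    while stack:
        node, upstream = stack.pop()
        grads[id(node)] = grads.get(id(node), 0.0) + upstream
        # Walk every path back to the leaves (fine for a toy graph).
        for parent, local in node.parents:
            stack.append((parent, upstream * local))
    return [grads.get(id(x)) for x in xs]


class VariableStore:
    """Hypothetical stand-in for a variable scope that accepts a custom getter."""

    def __init__(self, custom_getter=None):
        self.variables = {}
        self.custom_getter = custom_getter

    def _default_getter(self, name, initial_value):
        # Create the variable on first use, then keep returning it.
        if name not in self.variables:
            self.variables[name] = Node(initial_value)
        return self.variables[name]

    def get_variable(self, name, initial_value):
        if self.custom_getter is None:
            return self._default_getter(name, initial_value)
        # The custom getter receives the real getter and decides what to return.
        return self.custom_getter(self._default_getter, name, initial_value)


def stop_gradient_getter(getter, *args, **kwargs):
    """Create/fetch the variable as usual, then block gradients through it."""
    return stop_gradient(getter(*args, **kwargs))


def linear(store, x):
    """A one-unit 'linear layer': y = w * x + b."""
    w = store.get_variable("linear1/w", 2.0)
    b = store.get_variable("linear1/b", 0.5)
    return w * x + b


x = Node(3.0)

plain = VariableStore()
y_plain = linear(plain, x)
grads_plain = gradients(y_plain, [plain.variables["linear1/w"],
                                  plain.variables["linear1/b"]])

frozen = VariableStore(custom_getter=stop_gradient_getter)
y_frozen = linear(frozen, x)
grads_frozen = gradients(y_frozen, [frozen.variables["linear1/w"],
                                    frozen.variables["linear1/b"]])

print(y_plain.value, grads_plain)    # 6.5 [3.0, 1.0]
print(y_frozen.value, grads_frozen)  # 6.5 [None, None]
```

As in the test, the forward value is unchanged (6.5 in both cases) and the underlying variables still exist in the store; only the gradient path to them is cut, which is why the result is `None` rather than a zero gradient.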