def testUsage(self):
    """Checks that the `stop_gradient` custom getter blocks variable gradients.

    A `snt.Linear` module is constructed inside a variable scope whose custom
    getter is `snt.custom_getters.stop_gradient`. The test then verifies:

    * exactly two trainable variables exist, named `linear1/w:0` and
      `linear1/b:0`;
    * `tf.gradients` of the module output with respect to both is `None`;
    * a gradient still reaches the module's input.
    """
    with tf.variable_scope("", custom_getter=snt.custom_getters.stop_gradient):
        lin1 = snt.Linear(10, name="linear1")

    # Only construction happens inside the scope; connecting the module below
    # presumably relies on the getter captured at construction time.
    # NOTE(review): SOURCE lost its indentation -- confirm these lines were
    # meant to sit outside the `with` block.
    x = tf.placeholder(tf.float32, [10, 10])
    y = lin1(x)

    variables = tf.trainable_variables()
    variable_names = [v.name for v in variables]
    self.assertEqual(2, len(variables))
    self.assertIn("linear1/w:0", variable_names)
    self.assertIn("linear1/b:0", variable_names)

    grads = tf.gradients(y, variables)
    names_to_grads = {var.name: grad for var, grad in zip(variables, grads)}
    # `tf.gradients` returns None for inputs with no differentiable path to
    # `y`. Use identity checks rather than `==` against a possible Tensor.
    self.assertIsNone(names_to_grads["linear1/w:0"])
    self.assertIsNone(names_to_grads["linear1/b:0"])

    # Stopping gradients on the parameters must not cut the path to the input.
    self.assertIsNotNone(tf.gradients(y, x)[0])
# 评论列表 (scraped web-page residue: "comment list")
# 文章目录 (scraped web-page residue: "table of contents")