def testCustomGetter(self):
    """Check that custom getters work appropriately.

    A getter that forces ``trainable=False`` is passed to ``snt.Linear``
    twice: first as a single callable (expected to cover both ``w`` and
    ``b``), then as a dict keyed only by ``"w"`` (expected to cover ``w``
    alone). After each module is connected, the trainable and global
    variable collections are checked.
    """
    def custom_getter(getter, *args, **kwargs):
        # Override whatever trainable flag the module requested.
        kwargs["trainable"] = False
        return getter(*args, **kwargs)

    inputs = tf.placeholder(tf.float32, shape=[self.batch_size, self.in_size])

    # Make w and b non-trainable.
    lin1 = snt.Linear(output_size=self.out_size,
                      custom_getter=custom_getter)
    lin1(inputs)
    self.assertEqual(0, len(tf.trainable_variables()))
    self.assertEqual(2, len(tf.global_variables()))

    # Make w non-trainable.
    lin2 = snt.Linear(output_size=self.out_size,
                      custom_getter={"w": custom_getter})
    lin2(inputs)
    # A bare count would not notice the getter landing on the wrong
    # variable, so also pin that the one remaining trainable variable is
    # lin2's bias. Names are compared rather than Variable objects to avoid
    # relying on Variable equality semantics.
    trainable_names = [v.name for v in tf.trainable_variables()]
    self.assertEqual(1, len(trainable_names))
    self.assertEqual([lin2.b.name], trainable_names)
    self.assertEqual(4, len(tf.global_variables()))
评论列表
文章目录