basic_test.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:sonnet 作者: deepmind 项目源码 文件源码
def testCustomGetter(self):
    """Verify that custom getters can override variable-creation arguments.

    A getter that forces ``trainable=False`` is first applied to every
    variable of a ``snt.Linear`` module, then (via the dict form) only to
    its ``"w"`` variable. Trainable and global variable counts are checked
    after each module is connected.
    """

    def non_trainable_getter(getter, *args, **kwargs):
      # Delegate to the wrapped getter, forcing the trainable flag off.
      overridden = dict(kwargs, trainable=False)
      return getter(*args, **overridden)

    placeholder = tf.placeholder(tf.float32,
                                 shape=[self.batch_size, self.in_size])

    # Case 1: the getter wraps every variable, so neither w nor b is trainable.
    all_frozen = snt.Linear(output_size=self.out_size,
                            custom_getter=non_trainable_getter)
    all_frozen(placeholder)
    self.assertEqual(0, len(tf.trainable_variables()))
    self.assertEqual(2, len(tf.global_variables()))

    # Case 2: the getter is keyed to "w" only, so b remains trainable.
    w_frozen = snt.Linear(output_size=self.out_size,
                          custom_getter={"w": non_trainable_getter})
    w_frozen(placeholder)
    # Counts accumulate across both modules (2 globals from case 1 + 2 here).
    self.assertEqual(1, len(tf.trainable_variables()))
    self.assertEqual(4, len(tf.global_variables()))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号