layers_test.py source code

python

Project: DeepLearning_VirtualReality_BigData_Project Author: rashmitripathi
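# Excerpt of a test method from layers_test.py. The names ops, random_ops,
# layers_lib, variables and variables_lib are modules imported elsewhere in
# that file (the imports are not part of this excerpt).
def testTrainableFlagIsPassedOn(self):
  # Run the check once with trainable=True and once with trainable=False.
  for trainable in [True, False]:
    # Use a fresh graph so variables from the other case do not leak in.
    with ops.Graph().as_default():
      num_filters = 32
      input_size = [5, 10, 12, 3]  # presumably [batch, height, width, channels]

      images = random_ops.random_uniform(input_size, seed=1)
      layers_lib.conv2d_transpose(
          images, num_filters, [3, 3], stride=1, trainable=trainable)
      model_variables = variables.get_model_variables()
      trainable_variables = variables_lib.trainable_variables()
      # Each model variable must be trainable exactly when the flag is set.
      for model_variable in model_variables:
        self.assertEqual(trainable, model_variable in trainable_variables)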
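
The invariant this test checks can be illustrated without TensorFlow. The sketch below is a toy model only: `ToyGraph`, `toy_conv2d_transpose`, and `check_trainable_flag_is_passed_on` are hypothetical names invented for illustration and are not part of any TensorFlow API. It assumes a layer registers every variable it creates as a model variable and, only when `trainable` is true, also as a trainable variable.

```python
# Toy stand-in for graph variable collections; NOT the TensorFlow API.
class ToyGraph:
    def __init__(self):
        self.model_variables = []
        self.trainable_variables = []

    def add_variable(self, name, trainable):
        # Every layer variable is recorded as a model variable; only
        # trainable ones are also recorded in the trainable collection.
        self.model_variables.append(name)
        if trainable:
            self.trainable_variables.append(name)


def toy_conv2d_transpose(graph, trainable):
    # Hypothetical layer: creates two variables and forwards the flag.
    graph.add_variable("weights", trainable)
    graph.add_variable("biases", trainable)


def check_trainable_flag_is_passed_on():
    results = {}
    for trainable in [True, False]:
        graph = ToyGraph()  # fresh graph per case, mirroring the test above
        toy_conv2d_transpose(graph, trainable)
        # Same assertion as the test: membership must equal the flag.
        for name in graph.model_variables:
            assert trainable == (name in graph.trainable_variables)
        results[trainable] = list(graph.trainable_variables)
    return results
```

Calling `check_trainable_flag_is_passed_on()` raises an `AssertionError` if the toy layer ever drops the flag, which is the same failure mode the original test guards against.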