def testTrainableFlagIsPassedOn(self):
    """Check that `trainable` passed to conv2d_transpose reaches its variables.

    For each flag value a new graph is entered as the default, a
    `layers_lib.conv2d_transpose` call is made on random-uniform images,
    and every model variable is asserted to appear in the list of
    trainable variables exactly when the flag is True.
    """
    for is_trainable in (True, False):
        # Each flag value gets its own default graph.
        with ops.Graph().as_default():
            image_shape = [5, 10, 12, 3]
            output_depth = 32
            inputs = random_ops.random_uniform(image_shape, seed=1)

            # The return value is not used; the checks below inspect the
            # variable collections instead.
            layers_lib.conv2d_transpose(
                inputs, output_depth, [3, 3], stride=1,
                trainable=is_trainable)

            trainable_vars = variables_lib.trainable_variables()
            for var in variables.get_model_variables():
                # Membership must match the flag in both directions:
                # listed when True, absent when False.
                is_listed = var in trainable_vars
                self.assertEqual(is_trainable, is_listed)
layers_test.py 文件源码
python
阅读 21
收藏 0
点赞 0
评论 0
评论列表
文章目录