def testBuildEndPoints(self):
    """Verify the end points returned by ``inception.inception_resnet_v2``.

    Builds the network on a random-uniform input of shape
    ``(batch_size, 299, 299, 3)`` with 1000 classes and checks that:

    * ``'Logits'`` and ``'AuxLogits'`` are present in ``end_points`` and
      each has shape ``[batch_size, num_classes]``;
    * ``'PrePool'`` is present in ``end_points`` and has shape
      ``[batch_size, 8, 8, 1536]``.
    """
    batch_size = 5
    height, width = 299, 299
    num_classes = 1000
    with self.test_session():
        inputs = tf.random_uniform((batch_size, height, width, 3))
        # Only the end_points dict is inspected; the first return value
        # is intentionally discarded.
        _, end_points = inception.inception_resnet_v2(inputs, num_classes)

        # assertIn names the missing key and the container on failure,
        # whereas assertTrue(key in d) only reports "False is not true".
        self.assertIn('Logits', end_points)
        logits = end_points['Logits']
        self.assertListEqual(logits.get_shape().as_list(),
                             [batch_size, num_classes])

        self.assertIn('AuxLogits', end_points)
        aux_logits = end_points['AuxLogits']
        self.assertListEqual(aux_logits.get_shape().as_list(),
                             [batch_size, num_classes])

        # Check membership first, as for the other end points, so a missing
        # key is reported as a test failure rather than a KeyError error.
        self.assertIn('PrePool', end_points)
        pre_pool = end_points['PrePool']
        self.assertListEqual(pre_pool.get_shape().as_list(),
                             [batch_size, 8, 8, 1536])
inception_resnet_v2_test.py 文件源码
python
阅读 34
收藏 0
点赞 0
评论 0
评论列表
文章目录