def testBuildLogits(self):
    """Check op names and static shapes of the InceptionV4 outputs.

    Builds ``inception.inception_v4`` on a random input batch of shape
    ``(5, 299, 299, 3)`` with 1000 classes, then verifies that the returned
    logits and the ``'AuxLogits'`` / ``'Predictions'`` end points each have
    an op name starting with the expected scope prefix and a static shape
    equal to ``[batch_size, num_classes]``.
    """
    batch_size = 5
    height, width = 299, 299
    num_classes = 1000
    inputs = tf.random_uniform((batch_size, height, width, 3))
    logits, end_points = inception.inception_v4(inputs, num_classes)

    # The auxiliary logits and the predictions are only exposed through the
    # end_points mapping; the main logits are returned directly.
    auxlogits = end_points['AuxLogits']
    predictions = end_points['Predictions']

    self.assertTrue(auxlogits.op.name.startswith('InceptionV4/AuxLogits'))
    self.assertListEqual(auxlogits.get_shape().as_list(),
                         [batch_size, num_classes])

    self.assertTrue(logits.op.name.startswith('InceptionV4/Logits'))
    self.assertListEqual(logits.get_shape().as_list(),
                         [batch_size, num_classes])

    self.assertTrue(predictions.op.name.startswith(
        'InceptionV4/Logits/Predictions'))
    self.assertListEqual(predictions.get_shape().as_list(),
                         [batch_size, num_classes])
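# (web-page residue, not code) 评论列表 -- "comment list"
# (web-page residue, not code) 文章目录 -- "article table of contents"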