def testUnknowBatchSize(self):
    """Check that inception_v2 works when the batch dimension is unknown.

    The graph is built on a placeholder whose leading (batch) dimension is
    ``None``. The test verifies the logits op name prefix and the static
    logits shape, then feeds a concrete batch and verifies the shape of the
    evaluated output.
    """
    batch_size = 1
    height, width = 224, 224
    num_classes = 1000

    # Leave the batch dimension unspecified at graph-construction time.
    inputs = tf.placeholder(tf.float32, (None, height, width, 3))
    logits, _ = inception.inception_v2(inputs, num_classes)

    self.assertTrue(logits.op.name.startswith('InceptionV2/Logits'))
    # The static shape must keep the unknown batch dimension.
    self.assertListEqual(logits.get_shape().as_list(),
                         [None, num_classes])

    images = tf.random_uniform((batch_size, height, width, 3))
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        # Evaluate the random tensor first so a concrete array is fed
        # through the placeholder.
        output = sess.run(logits, {inputs: images.eval()})
        # assertEqual replaces the deprecated assertEquals alias, which was
        # removed from unittest in Python 3.12.
        self.assertEqual(output.shape, (batch_size, num_classes))
评论列表
文章目录