inception_resnet_v2_test.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:tf_classification 作者: visipedia 项目源码 文件源码
def testBuildEndPoints(self):
    batch_size = 5
    height, width = 299, 299
    num_classes = 1000
    with self.test_session():
      inputs = tf.random_uniform((batch_size, height, width, 3))
      _, end_points = inception.inception_resnet_v2(inputs, num_classes)
      self.assertTrue('Logits' in end_points)
      logits = end_points['Logits']
      self.assertListEqual(logits.get_shape().as_list(),
                           [batch_size, num_classes])
      self.assertTrue('AuxLogits' in end_points)
      aux_logits = end_points['AuxLogits']
      self.assertListEqual(aux_logits.get_shape().as_list(),
                           [batch_size, num_classes])
      pre_pool = end_points['PrePool']
      self.assertListEqual(pre_pool.get_shape().as_list(),
                           [batch_size, 8, 8, 1536])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号