resnet_v1_test.py source code

python

Project: isbi2017-part3  Author: learningtitans
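```python
# Test method excerpted from resnet_v1_test.py. `tf`, `slim`, `resnet_utils`,
# `create_test_input` and `self._resnet_small` are defined elsewhere in that file.
def testUnknownBatchSize(self):
  batch = 2
  height, width = 65, 65
  global_pool = True
  num_classes = 10
  # Build the network on an input whose batch dimension is None (unknown).
  inputs = create_test_input(None, height, width, 3)
  with slim.arg_scope(resnet_utils.resnet_arg_scope()):
    logits, _ = self._resnet_small(inputs, num_classes,
                                   global_pool=global_pool,
                                   scope='resnet')
  self.assertTrue(logits.op.name.startswith('resnet/logits'))
  # Static shape: the batch stays unknown, and global pooling collapses
  # the spatial dimensions to 1x1.
  self.assertListEqual(logits.get_shape().as_list(),
                       [None, 1, 1, num_classes])
  # Feed a concrete batch and check the shape of the runtime output.
  images = create_test_input(batch, height, width, 3)
  with self.test_session() as sess:
    # global_variables_initializer() replaces the deprecated
    # tf.initialize_all_variables().
    sess.run(tf.global_variables_initializer())
    output = sess.run(logits, {inputs: images.eval()})
    self.assertEqual(output.shape, (batch, 1, 1, num_classes))
```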
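The test relies on a shape contract: the batch dimension passes through untouched, global pooling reduces any spatial size to 1x1, and the logits layer maps features to `num_classes`. The sketch below checks that contract in plain Python, without TensorFlow. It is a simplification under stated assumptions: the residual blocks are skipped entirely, `make_test_images` is a hypothetical stand-in for the test helper (pixel (i, j) holds i + j in every channel), and the logits layer is assumed to be a 1x1 convolution, as in TF-slim's ResNet.

```python
import random


def make_test_images(batch_size, height, width, channels):
    """Hypothetical test input: N x H x W x C nested lists, pixel (i, j) = i + j."""
    return [[[[float(i + j)] * channels for j in range(width)]
             for i in range(height)]
            for _ in range(batch_size)]


def global_avg_pool(batch):
    """Average each image over height and width: N x H x W x C -> N x 1 x 1 x C."""
    pooled = []
    for image in batch:
        height, width, channels = len(image), len(image[0]), len(image[0][0])
        sums = [0.0] * channels
        for row in image:
            for pixel in row:
                for k in range(channels):
                    sums[k] += pixel[k]
        pooled.append([[[s / (height * width) for s in sums]]])
    return pooled


def conv1x1_logits(pooled, weights):
    """On a 1x1 map, a 1x1 convolution is a matrix product: C features -> num_classes."""
    out = []
    for image in pooled:
        features = image[0][0]
        logits = [sum(f * weights[k][c] for k, f in enumerate(features))
                  for c in range(len(weights[0]))]
        out.append([[logits]])
    return out


def shape(nested):
    """Shape of a rectangular nested list, as a tuple."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


num_classes = 10
rng = random.Random(0)
# Assumed weights: 3 input channels x num_classes outputs.
weights = [[rng.uniform(-1, 1) for _ in range(num_classes)] for _ in range(3)]

# The same "graph" handles any batch size, like the None batch dimension.
for batch in (1, 2, 5):
    logits = conv1x1_logits(global_avg_pool(make_test_images(batch, 65, 65, 3)), weights)
    print(batch, shape(logits))
# 1 (1, 1, 1, 10)
# 2 (2, 1, 1, 10)
# 5 (5, 1, 1, 10)
```

Because pooling averages over whatever height and width arrive, the same weights work for any batch size, which is why the graph can be built once with an unknown batch dimension and fed a concrete batch later.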