mobilenet_v1_test.py source code

python

Project: tf_classification  Author: visipedia
def testBuildOnlyUptoFinalEndpoint(self):
    batch_size = 5
    height, width = 224, 224
    endpoints = ['Conv2d_0',
                 'Conv2d_1_depthwise', 'Conv2d_1_pointwise',
                 'Conv2d_2_depthwise', 'Conv2d_2_pointwise',
                 'Conv2d_3_depthwise', 'Conv2d_3_pointwise',
                 'Conv2d_4_depthwise', 'Conv2d_4_pointwise',
                 'Conv2d_5_depthwise', 'Conv2d_5_pointwise',
                 'Conv2d_6_depthwise', 'Conv2d_6_pointwise',
                 'Conv2d_7_depthwise', 'Conv2d_7_pointwise',
                 'Conv2d_8_depthwise', 'Conv2d_8_pointwise',
                 'Conv2d_9_depthwise', 'Conv2d_9_pointwise',
                 'Conv2d_10_depthwise', 'Conv2d_10_pointwise',
                 'Conv2d_11_depthwise', 'Conv2d_11_pointwise',
                 'Conv2d_12_depthwise', 'Conv2d_12_pointwise',
                 'Conv2d_13_depthwise', 'Conv2d_13_pointwise']
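    # Build a fresh graph for each endpoint so ops from earlier iterations do
    # not leak in and get their scopes renamed (e.g. 'MobilenetV1_1/...').
    for index, endpoint in enumerate(endpoints):
      with tf.Graph().as_default():
        inputs = tf.random_uniform((batch_size, height, width, 3))
        # Build the base network only up to (and including) `endpoint`.
        out_tensor, end_points = mobilenet_v1.mobilenet_v1_base(
            inputs, final_endpoint=endpoint)
        # The returned tensor should come from the requested layer...
        self.assertTrue(out_tensor.op.name.startswith(
            'MobilenetV1/' + endpoint))
        # ...and exactly the endpoints up to this one should be recorded
        # (order-insensitive comparison).
        self.assertItemsEqual(endpoints[:index+1], end_points)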
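The test relies on one invariant: asking the base network to stop at the k-th endpoint yields an output from that layer and records exactly the first k endpoints. The pure-Python sketch below models only that bookkeeping so it can run without TensorFlow. `build_up_to` and `ENDPOINTS` are hypothetical names, and the strings stand in for real op names; this is not the actual `mobilenet_v1_base` implementation.

```python
from collections import OrderedDict

# Endpoint names in the order the test lists them: one stem convolution,
# then 13 depthwise/pointwise pairs (27 names in total).
ENDPOINTS = ['Conv2d_0'] + [
    'Conv2d_%d_%s' % (i, kind)
    for i in range(1, 14)
    for kind in ('depthwise', 'pointwise')
]


def build_up_to(final_endpoint, scope='MobilenetV1'):
    """Hypothetical stand-in for a base-network builder.

    Records each layer name in build order and stops as soon as
    `final_endpoint` has been built. Returns (output_name, end_points).
    """
    if final_endpoint not in ENDPOINTS:
        raise ValueError('Unknown final_endpoint %s' % final_endpoint)
    end_points = OrderedDict()
    for name in ENDPOINTS:
        net = scope + '/' + name  # placeholder for the layer's output op name
        end_points[name] = net
        if name == final_endpoint:
            return net, end_points


# Same checks as the test, applied to the sketch.
for index, endpoint in enumerate(ENDPOINTS):
    out, end_points = build_up_to(endpoint)
    assert out.startswith('MobilenetV1/' + endpoint)
    # sorted() comparison mirrors the order-insensitive assertItemsEqual.
    assert sorted(end_points) == sorted(ENDPOINTS[:index + 1])
```

Returning early from the build loop is what keeps later layers out of both the graph and the `end_points` dictionary, which is why the test can compare against the prefix `endpoints[:index+1]`.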