mobilenet_v1_test.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:tf_classification 作者: visipedia 项目源码 文件源码
def testBuildCustomNetworkUsingConvDefs(self):
    """Build a truncated MobileNet v1 from explicit conv_defs and verify it.

    Four layer definitions are supplied: a `Conv` def followed by three
    `DepthSepConv` defs. Building up to 'Conv2d_3_pointwise' is expected to
    yield an output op under the 'MobilenetV1/Conv2d_3' name prefix with
    static shape [batch, 56, 56, 512], and exactly the seven end points
    listed below.
    """
    num_images = 5
    img_h = img_w = 224
    custom_layers = [
        mobilenet_v1.Conv(kernel=[3, 3], stride=2, depth=32),
        mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=64),
        mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=2, depth=128),
        mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=512),
    ]

    images = tf.random_uniform((num_images, img_h, img_w, 3))
    output, layer_outputs = mobilenet_v1.mobilenet_v1_base(
        images,
        final_endpoint='Conv2d_3_pointwise',
        conv_defs=custom_layers,
    )

    # The returned tensor's op should be named after the last requested layer.
    self.assertTrue(output.op.name.startswith('MobilenetV1/Conv2d_3'))

    # 56 = 224 / 2 / 2 (two stride-2 defs above); 512 is the last def's depth.
    expected_shape = [num_images, 56, 56, 512]
    self.assertListEqual(output.get_shape().as_list(), expected_shape)

    # The first def contributes one end point; each DepthSepConv def
    # contributes a depthwise and a pointwise end point.
    wanted_names = [
        'Conv2d_0',
        'Conv2d_1_depthwise', 'Conv2d_1_pointwise',
        'Conv2d_2_depthwise', 'Conv2d_2_pointwise',
        'Conv2d_3_depthwise', 'Conv2d_3_pointwise',
    ]
    self.assertItemsEqual(layer_outputs.keys(), wanted_names)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号