build_vgg.py source code

python

Project: tensorflow-litterbox  Author: rwightman
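import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.contrib.framework import arg_scope


def _build_vgg16(
        inputs,
        num_classes=1000,
        dropout_keep_prob=0.5,
        is_training=True,
        scope=''):
    """Build a VGG-16 network.

    Args:
        inputs: 4-D image batch tensor (NHWC).
        num_classes: number of output classes.
        dropout_keep_prob: keep probability for dropout in the output block.
        is_training: training-mode flag passed to batch_norm and dropout.
        scope: name scope for the created ops.

    Returns:
        A tuple (logits, endpoints). endpoints is a dict of intermediate
        tensors collected by the blocks, plus 'Predictions' (softmax of logits).
    """
    endpoints = {}
    with tf.name_scope(scope, 'vgg16', [inputs]):
        # Batch norm and dropout layers inherit is_training from this scope.
        with arg_scope(
                [layers.batch_norm, layers.dropout], is_training=is_training):
            # Convolutions and max pooling default to stride 1, 'SAME' padding.
            with arg_scope(
                    [layers.conv2d, layers.max_pool2d],
                    stride=1,
                    padding='SAME'):

                # Five convolutional stages, widths 64 -> 512.
                # _block_a and _block_b are defined elsewhere in build_vgg.py.
                net = _block_a(inputs, endpoints, d=64, scope='Scale1')
                net = _block_a(net, endpoints, d=128, scope='Scale2')
                net = _block_b(net, endpoints, d=256, scope='Scale3')
                net = _block_b(net, endpoints, d=512, scope='Scale4')
                net = _block_b(net, endpoints, d=512, scope='Scale5')
                # Output block (defined elsewhere) returns unnormalized logits.
                logits = _block_output(net, endpoints, num_classes, dropout_keep_prob)

                endpoints['Predictions'] = tf.nn.softmax(logits, name='Predictions')
                return logits, endpoints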
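The five `_block_*` calls form VGG-16's convolutional trunk, but the helpers themselves are not shown. The sketch below is a minimal, framework-free shape trace under stated assumptions: it assumes the standard VGG-16 layout (two 3x3 conv layers per stage for `Scale1`/`Scale2`, three for `Scale3` to `Scale5`, each stage ending in a 2x2 max pool with stride 2), which is a guess about what distinguishes `_block_a` from `_block_b`. The names `VGG16_STAGES` and `trace_vgg16` are hypothetical and not part of build_vgg.py.

```python
# Hypothetical stage table, ASSUMED to match standard VGG-16:
# (scope, number of 3x3 conv layers, output channels d).
VGG16_STAGES = [
    ("Scale1", 2, 64),
    ("Scale2", 2, 128),
    ("Scale3", 3, 256),
    ("Scale4", 3, 512),
    ("Scale5", 3, 512),
]


def trace_vgg16(height=224, width=224, in_channels=3):
    """Return per-stage output shapes (H, W, C) and the conv weight+bias count.

    Assumes 'SAME' 3x3 convs (spatial size unchanged) and a 2x2 max pool
    with stride 2 at the end of every stage.
    """
    shapes = {}
    params = 0
    channels = in_channels
    for scope, n_convs, d in VGG16_STAGES:
        for _ in range(n_convs):
            params += 3 * 3 * channels * d + d  # kernel weights + biases
            channels = d
        # 'SAME' pooling with stride 2 yields ceil(size / 2).
        height = -(-height // 2)
        width = -(-width // 2)
        shapes[scope] = (height, width, channels)
    return shapes, params


if __name__ == "__main__":
    shapes, params = trace_vgg16()
    for scope, shape in shapes.items():
        print(scope, shape)
    print("conv params:", params)
```

Under these assumptions a 224x224x3 input reaches `_block_output` as a 7x7x512 feature map, and the trunk holds 14,714,688 conv weights and biases. If the real helpers use batch norm in place of conv biases, the count will differ slightly.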