def squeezenet(inputs,
               num_classes=1000,
               is_training=True,
               keep_prob=0.5,
               spatial_squeeze=True,
               scope='squeeze'):
    """Build the SqueezeNet v1.1 classification head on top of the trunk.

    Runs ``squeezenet_inference`` on ``inputs`` and appends a 1x1
    convolution, named ``logits``, that projects to ``num_classes``
    channels with no activation and no normalizer.

    Args:
      inputs: tensor handed to ``squeezenet_inference`` (presumably a batch
        of images -- TODO confirm layout against that helper).
      num_classes: number of output channels of the ``logits`` convolution.
      is_training: forwarded unchanged to ``squeezenet_inference``.
      keep_prob: forwarded unchanged to ``squeezenet_inference`` (the name
        suggests a dropout keep probability -- confirm in that helper).
      spatial_squeeze: when True, axes 1 and 2 of the logits are removed
        with ``tf.squeeze``; that op raises if either axis is not size 1.
      scope: name given to ``tf.name_scope``; ``'squeeze'`` is the default
        name used when ``scope`` is None.

    Returns:
      ``(logits, end_points)``. ``end_points`` is the dict produced by
      ``slim.utils.convert_collection_to_dict`` from the outputs recorded
      by conv2d / max_pool2d / avg_pool2d / fire_module calls in this scope
      (the ``logits`` conv included). It is built before the optional
      squeeze, so the squeezed tensor is not part of it.
    """
    with tf.name_scope(scope, 'squeeze', [inputs]) as scope_name:
        # Derive the end-points collection name from the resolved
        # name-scope string.
        collection_name = scope_name + '_end_points'
        # Layer functions whose outputs are recorded into collection_name.
        recorded_layers = [slim.conv2d, slim.max_pool2d,
                           slim.avg_pool2d, fire_module]
        with slim.arg_scope(recorded_layers,
                            outputs_collections=collection_name):
            features = squeezenet_inference(inputs, is_training, keep_prob)
            # Linear 1x1 projection to per-class scores.
            logits = slim.conv2d(features,
                                 num_outputs=num_classes,
                                 kernel_size=[1, 1],
                                 activation_fn=None,
                                 normalizer_fn=None,
                                 scope='logits')
            # Snapshot the recorded end points before any squeezing.
            end_point_map = slim.utils.convert_collection_to_dict(
                collection_name)
            if spatial_squeeze:
                logits = tf.squeeze(logits, [1, 2], name='logits/squeezed')
    return logits, end_point_map
评论列表
文章目录