def build_finetunable_model(inputs, cfg):
    """Build an Inception-ResNet-v2 backbone plus detection heads for fine-tuning.

    The backbone is constructed with batch norm in inference mode
    (``is_training=False``) and with no variable collections, so its batch-norm
    statistics stay frozen; the detection heads are then built with batch norm
    in training mode and their moving averages tracked — the standard
    frozen-backbone fine-tuning setup.

    Args:
        inputs: Input image tensor fed to the Inception-ResNet-v2 backbone.
        cfg: Config object; reads ``BATCHNORM_MOVING_AVERAGE_DECAY`` and
            ``NUM_BBOXES_PER_CELL``.

    Returns:
        Tuple ``(locs, confs, inception_vars, detection_vars)``:
        ``locs``/``confs`` are the detection-head outputs, and the two dicts
        map variable op names to variables for the backbone and for the heads
        respectively (kept separate for ease of restoring from checkpoints).
    """
    with slim.arg_scope([slim.conv2d],
                        activation_fn=tf.nn.relu,
                        normalizer_fn=slim.batch_norm,
                        weights_regularizer=slim.l2_regularizer(0.00004),
                        biases_regularizer=slim.l2_regularizer(0.00004)):

        # Frozen backbone: batch norm runs in inference mode and its moving
        # averages are not registered in any collection (so they won't be
        # updated during fine-tuning).
        batch_norm_params = {
            'decay': cfg.BATCHNORM_MOVING_AVERAGE_DECAY,
            'epsilon': 0.001,
            'variables_collections': [],
            'is_training': False
        }
        with slim.arg_scope([slim.conv2d], normalizer_params=batch_norm_params):
            features, _ = model.inception_resnet_v2(inputs, reuse=False, scope='InceptionResnetV2')

        # Save off the original variables (for ease of restoring)
        model_variables = slim.get_model_variables()
        inception_vars = {var.op.name: var for var in model_variables}

        # Trainable heads: batch norm in training mode, with moving averages
        # placed in the MOVING_AVERAGE_VARIABLES collection so they update.
        batch_norm_params = {
            'decay': cfg.BATCHNORM_MOVING_AVERAGE_DECAY,
            'epsilon': 0.001,
            'variables_collections': [tf.GraphKeys.MOVING_AVERAGE_VARIABLES],
            'is_training': True
        }
        with slim.arg_scope([slim.conv2d], normalizer_params=batch_norm_params):
            # Add on the detection heads
            locs, confs, _ = model.build_detection_heads(features, cfg.NUM_BBOXES_PER_CELL)

        # Everything created after the backbone snapshot belongs to the heads.
        model_variables = slim.get_model_variables()
        detection_vars = {var.op.name: var for var in model_variables
                          if var.op.name not in inception_vars}

    return locs, confs, inception_vars, detection_vars
# Page-scrape residue, not source code — kept as a comment so the file parses:
# "评论列表" ("Comment list") / "文章目录" ("Table of contents")