def maps_pred_fun(checkpoint):
    """Build a backend function that returns class activation maps (CAM).

    Loads the model from ``checkpoint`` and takes the feature maps that feed
    its last ``GlobalAveragePooling2D`` layer. If the model has exactly one
    ``Dense`` layer, that layer's weights are re-applied as a 1x1 convolution,
    so each output channel becomes a per-class heatmap. If there is no
    ``Dense`` layer (NiN-style), the pooled maps are used as class maps
    directly. The maps are then resized bilinearly to the spatial size of the
    model input.

    NOTE(review): the resize uses ``K.shape(x)[1:3]`` with
    ``tf.image.resize_images``, and the dense kernel is reshaped to
    ``(1, 1, in, out)``. Together these assume channels-last (NHWC) tensors.
    TODO confirm the model's dim ordering.

    Args:
        checkpoint: passed unchanged to ``load_model``.

    Returns:
        A ``K.function`` that takes ``[inputs, learning_phase]`` and returns
        ``[maps, model_output]``.

    Raises:
        ValueError: if the model has no ``GlobalAveragePooling2D`` layer, or
            has more than one ``Dense`` layer.
    """
    import numpy as np  # local import: only needed for the no-bias fallback

    model = load_model(checkpoint)
    x = model.input

    # Feature maps right before the (last) global average pooling layer.
    gap_layers = [l for l in model.layers
                  if isinstance(l, GlobalAveragePooling2D)]
    if not gap_layers:
        raise ValueError('Expected a GlobalAveragePooling2D layer, found none')
    o = gap_layers[-1].input

    # Setup CAM: at most one dense layer is supported.
    dense_list = [l for l in model.layers if isinstance(l, Dense)]
    num_dense = len(dense_list)
    if num_dense > 1:
        raise ValueError('Expected only one dense layer, found %d' % num_dense)

    # With no dense layer (NiN), the maps are already class maps.
    if dense_list:
        # get_weights() is the public, version-portable accessor; it returns
        # [kernel] or [kernel, bias] as numpy arrays.
        weights = dense_list[0].get_weights()
        W = weights[0][None, None]  # (in, out) -> (1, 1, in, out) 1x1 kernel
        if len(weights) > 1:
            b = weights[1]
        else:
            # Dense layer built without a bias: use a zero bias so the conv
            # reproduces the dense layer exactly.
            b = np.zeros(W.shape[-1], dtype=W.dtype)
        # A 1x1 conv with the dense kernel applies the dense layer at every
        # spatial position. No activation is passed, so the output is the
        # pre-activation class score map.
        o = Convolution2D(W.shape[-1], 1, 1, border_mode='valid',
                          weights=[W, b])(o)

    # Upsample the class maps to the input's spatial size.
    maps = tf.image.resize_images(o, K.shape(x)[1:3],
                                  method=tf.image.ResizeMethod.BILINEAR)
    return K.function([x, K.learning_phase()], [maps, model.output])
# (Leftover page labels, translated: "Comment list" / "Table of contents")