def extract(self, x, layers=('conv2',)):
    """Run ``x`` through every stage in ``self.functions`` and collect selected outputs.

    The stages are visited in the iteration order of ``self.functions.items()``.
    Each entry is a ``(name, funcs)`` pair: the callables in ``funcs`` are applied
    to the running value one after another. When ``name`` is one of the requested
    ``layers``, the running value after that stage's last callable is recorded.
    Requested names that never appear are ignored.

    NOTE(review): presumably ``self.functions`` is an ordered mapping of layer
    name -> list of callables (e.g. an OrderedDict); confirm against the class.

    Args:
        x: Input fed to the first callable.
        layers: Iterable of layer names whose outputs should be recorded. A single
            string is treated as one layer name, not as a sequence of characters.

    Returns:
        A tuple ``(activations, h)``. ``activations`` is the list of recorded
        outputs, ordered by their position in ``self.functions`` (not by their
        order in ``layers``). ``h`` is the output after the final stage.
    """
    # A bare string would otherwise be split into single characters by set().
    if isinstance(layers, str):
        layers = (layers,)

    h = x
    activations = []
    target_layers = set(layers)
    for key, funcs in self.functions.items():
        for func in funcs:
            h = func(h)
        # Record the stage output once, after all of its callables have run.
        if key in target_layers:
            activations.append(h)
            target_layers.remove(key)
    # No early exit: the final output h is returned too, so every stage must run.
    return activations, h
评论列表
文章目录