def get_extraction_target(inputs, outputs, to_extract, **loss_params):
    """Build the validation targets dict used for feature extraction.

    Example validation target function that provides targets for extracting
    features. In addition to the requested features, it adds a standard
    ``"loss"`` target, which you may or may not want.

    Args:
        inputs: Forwarded unchanged to ``utils.get_loss``.
        outputs: Forwarded unchanged to ``utils.get_loss``.
        to_extract: Dictionary of the form
            ``{name_for_saving: name_of_actual_tensor, ...}``, where
            ``name_for_saving`` is a human-friendly name to save the extracted
            features under, and ``name_of_actual_tensor`` is the name of the
            tensor in the TensorFlow default graph that outputs the desired
            features. Occurrences of ``__GPU__<digits>/`` (presumably
            per-GPU tower name scopes) are removed from graph tensor names
            before matching, so give the name without that prefix.
        **loss_params: Extra keyword arguments forwarded to
            ``utils.get_loss``.

    Returns:
        dict mapping each ``name_for_saving`` that matched at least one graph
        tensor to ``tf.concat`` (along axis 0) of all matching tensors, in
        graph-operation order, plus a ``"loss"`` entry produced by
        ``utils.get_loss``.

    Notes:
        * Matching is by substring: every graph tensor whose prefix-stripped
          name *contains* ``name_of_actual_tensor`` is collected. Use a
          specific name (e.g. including the ``:0`` output index) so that
          unrelated tensors whose names merely contain it are not swept in.
        * A ``name_for_saving`` that matches no tensor is silently omitted
          from the result.
        * A ``"loss"`` key in ``to_extract`` is overwritten by the loss
          target.
        * To list the tensor names available for ``to_extract``, evaluate
          ``[t.name for op in tf.get_default_graph().get_operations()
          for t in op.values()]``.
    """
    graph = tf.get_default_graph()
    # `\d+` rather than `\d` so that prefixes numbered 10 and above
    # (e.g. "__GPU__12/") are stripped too.
    gpu_prefix = re.compile(r'__GPU__\d+/')
    matched = defaultdict(list)
    # op.values() yields the op's output tensors directly; no need to
    # re-resolve each one by name with graph.get_tensor_by_name.
    for op in graph.get_operations():
        for tensor in op.values():
            bare_name = gpu_prefix.sub('', tensor.name)
            for save_name, actual_name in to_extract.items():
                if actual_name in bare_name:
                    matched[save_name].append(tensor)
    # Axis 0 is presumably the batch axis, so per-tower outputs are stacked
    # back into one batch -- confirm against the model's tensor layouts.
    targets = {k: tf.concat(v, axis=0) for k, v in matched.items()}
    targets['loss'] = utils.get_loss(inputs, outputs, **loss_params)
    return targets
评论列表
文章目录