test.py source code

python
Project: tfutils  Author: neuroailab
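import re
from collections import defaultdict

import tensorflow as tf  # TensorFlow 1.x graph-mode API

# NOTE: `utils` is the helper module the original test.py relies on for get_loss;
# its import is not shown in this excerpt.


def get_extraction_target(inputs, outputs, to_extract, **loss_params):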
    """Produce validation targets for feature extraction.

    Example validation target function to use to provide targets for extracting features.
    This function also adds a standard "loss" target, which you may or may not want.

    The to_extract argument must be a dictionary of the form
          {name_for_saving: name_of_actual_tensor, ...}
    where "name_for_saving" is a human-friendly name under which to save the extracted
    features, and name_of_actual_tensor is the name of the tensor in the tensorflow
    graph outputting the features to be extracted.  A tensor is selected when
    name_of_actual_tensor occurs as a substring of its name, ignoring any
    per-GPU prefix.  To find the tensor names to use in the "to_extract" argument,
    print the `names` list built below, which holds all available tensor names.

    """
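    # Collect the name of every output tensor of every operation in the default graph.
    names = [[x.name for x in op.values()] for op in tf.get_default_graph().get_operations()]
    # Flatten the per-operation lists into a single list of tensor names.
    names = [y for x in names for y in x]

    # Per-GPU copies of a tensor carry a prefix such as "__GPU__0/"; \d+ also covers
    # device indices with more than one digit.
    r = re.compile(r'__GPU__\d+/')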
    _targets = defaultdict(list)

    for name in names:
        name_without_gpu_prefix = r.sub('', name)
        for save_name, actual_name in to_extract.items():
            if actual_name in name_without_gpu_prefix:
                tensor = tf.get_default_graph().get_tensor_by_name(name)
                _targets[save_name].append(tensor)

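    # Join the matching tensors (e.g. one per GPU) along the first axis.
    targets = {k: tf.concat(v, axis=0) for k, v in _targets.items()}
    # Always add the standard loss target alongside the extracted features.
    targets['loss'] = utils.get_loss(inputs, outputs, **loss_params)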
    return targets
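
The name-matching step can be tried without building a TensorFlow graph. The sketch below is a stand-alone illustration, not part of tfutils: the tensor names, the `to_extract` entries, and the helper `group_names` are all hypothetical, and the prefix pattern is only modeled on the one used in the function above. It shows how each tensor name has its per-GPU prefix stripped and is then grouped under every save name whose target string occurs in it.

```python
import re
from collections import defaultdict

# Hypothetical tensor names, in the style a two-GPU graph might produce.
names = [
    "__GPU__0/model/conv1/relu:0",
    "__GPU__1/model/conv1/relu:0",
    "__GPU__0/model/fc/logits:0",
    "__GPU__1/model/fc/logits:0",
    "global_step:0",
]

# Hypothetical mapping: name to save under -> substring of the tensor name.
to_extract = {"conv1_features": "conv1/relu:0", "logits": "fc/logits:0"}


def group_names(names, to_extract):
    """Group full tensor names by save name, ignoring the per-GPU prefix."""
    gpu_prefix = re.compile(r'__GPU__\d+/')
    grouped = defaultdict(list)
    for name in names:
        stripped = gpu_prefix.sub('', name)
        for save_name, actual_name in to_extract.items():
            # Substring match, as in the function above.
            if actual_name in stripped:
                grouped[save_name].append(name)
    return dict(grouped)


grouped = group_names(names, to_extract)
for save_name, matched in grouped.items():
    print(save_name, matched)
```

Because the match is a substring test, a short target such as "relu" would collect every tensor whose name contains it, so it is usually safer to pass a target specific enough to identify one layer.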