test_base.py source code

python

Project: tfutils  Author: neuroailab
import re
from collections import defaultdict

import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.get_default_graph)
from tfutils import utils  # tfutils' own utils module, used below for utils.get_loss


def get_extraction_target(inputs, outputs, to_extract, **loss_params):
    """Build the validation targets used for feature extraction.

    Example validation target function that provides the tensors to extract as
    features. It also adds a standard "loss" target, which you may or may not want.

    The to_extract argument must be a dictionary of the form
          {name_for_saving: name_of_actual_tensor, ...}
    where "name_for_saving" is a human-friendly name under which the extracted
    features are saved, and "name_of_actual_tensor" is the name of the tensor in
    the TensorFlow graph that outputs the features you want to extract. To find
    the tensor names to pass in to_extract, uncomment the commented-out print
    line below, which prints a list of all available tensor names.
    """
    # Collect the name of every tensor output by every op in the default graph.
    names = [[x.name for x in op.values()] for op in tf.get_default_graph().get_operations()]
    names = [y for x in names for y in x]
    # Uncomment to print all available tensor names:
    # print(names)

    # Multi-GPU models prefix each tower's tensors with "__GPU__<n>/"; strip it before matching.
    r = re.compile(r'__GPU__\d/')
    _targets = defaultdict(list)

    for name in names:
        name_without_gpu_prefix = r.sub('', name)
        for save_name, actual_name in to_extract.items():
            # Substring match: every tensor whose name contains actual_name is collected.
            if actual_name in name_without_gpu_prefix:
                tensor = tf.get_default_graph().get_tensor_by_name(name)
                _targets[save_name].append(tensor)

    # Join the per-GPU copies of each target along the batch axis.
    targets = {k: tf.concat(v, axis=0) for k, v in _targets.items()}
    targets['loss'] = utils.get_loss(inputs, outputs, **loss_params)
    return targets
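Because the matching in get_extraction_target works purely on strings, it can be checked without building a TensorFlow graph. The sketch below reproduces only that loop in plain Python, so you can see which tensor names a given to_extract dictionary would collect. The helper group_tensor_names and the tensor names in the example are hypothetical. The regex and the substring test are copied from the function above.

```python
import re
from collections import defaultdict

# Same pattern get_extraction_target uses to strip multi-GPU tower prefixes.
GPU_PREFIX = re.compile(r'__GPU__\d/')


def group_tensor_names(names, to_extract):
    """Hypothetical helper: group tensor names by save name, ignoring the
    __GPU__<n>/ prefix. Mirrors the matching loop in get_extraction_target,
    but returns names instead of tensors so it runs without TensorFlow."""
    groups = defaultdict(list)
    for name in names:
        stripped = GPU_PREFIX.sub('', name)
        for save_name, actual_name in to_extract.items():
            # Substring match, as in the original function.
            if actual_name in stripped:
                groups[save_name].append(name)
    return dict(groups)


# Hypothetical tensor names from a model replicated on two GPUs.
names = [
    '__GPU__0/conv1/Relu:0',
    '__GPU__1/conv1/Relu:0',
    '__GPU__0/fc8/BiasAdd:0',
    '__GPU__1/fc8/BiasAdd:0',
    '__GPU__0/loss:0',
]
to_extract = {'conv1_features': 'conv1/Relu', 'logits': 'fc8/BiasAdd'}
print(group_tensor_names(names, to_extract))
```

Here both conv1/Relu towers are grouped under conv1_features, both fc8/BiasAdd towers under logits, and __GPU__0/loss:0 is ignored. In get_extraction_target, each list is then passed to tf.concat(..., axis=0), which joins the per-GPU batches into one. Because the test is a substring match, a short name such as conv1 would also match conv1/BiasAdd and any other tensor under that scope. The result could then contain unrelated tensors that may fail to concatenate, so precise names are safer.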