python类alias()的实例源码

__init__.py 文件源码 项目:ITAP-django 作者: krivers 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def basicTypeSpecialFunction(cv):
    """Lift a change on a metadata-free leaf value up to its owning AST node.

    Numbers, strings, names, arguments, name constants, bytes and import
    aliases have no metadata of their own.  When the first step of
    ``cv.path`` points at one of those leaf values, record the current
    subtree as ``cv.oldSubtree``, wrap ``cv.newSubtree`` in the matching AST
    node type (reusing ``ctx`` / ``annotation`` / ``asname`` from the old
    node where needed) and drop that first path step.  Swap and move
    vectors are returned unchanged.  Returns *cv* (mutated in place).
    """
    if isinstance(cv, (SwapVector, MoveVector)):
        return cv
    # Ordered (path step, node builder) pairs; compared with == so the
    # lookup does not require the path step to be hashable.
    leafBuilders = [
        (('n', 'Number'), lambda new, old: ast.Num(new)),
        (('s', 'String'), lambda new, old: ast.Str(new)),
        (('id', 'Name'), lambda new, old: ast.Name(new, old.ctx)),
        (('arg', 'Argument'), lambda new, old: ast.arg(new, old.annotation)),
        (('value', 'Name Constant'), lambda new, old: ast.NameConstant(new)),
        (('s', 'Bytes'), lambda new, old: ast.Bytes(new)),
        (('name', 'Alias'), lambda new, old: ast.alias(new, old.asname)),
    ]
    firstStep = cv.path[0]
    for leafKind, build in leafBuilders:
        if firstStep == leafKind:
            snapshot = cv.deepcopy()
            cv.oldSubtree = deepcopy(snapshot.traverseTree(cv.start))
            cv.newSubtree = build(cv.newSubtree, cv.oldSubtree)
            cv.path = cv.path[1:]
            break
    return cv
astTools.py 文件源码 项目:ITAP-django 作者: krivers 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def getAllImports(a):
    """Gather all imported module names.

    Walks every node of *a*.  For each ``import X [as Y]`` whose name is in
    ``supportedLibraries``, and each ``from M import X [as Y]`` where ``M``
    is in ``supportedLibraries`` and ``X`` is in ``libraryMap[M]``, the name
    bound locally (``Y`` if given, otherwise ``X``) is collected.  Anything
    unsupported is reported through ``log(..., "bug")`` instead.

    Returns a list of the collected names, or ``[]`` when *a* is not an
    ``ast.AST``.
    """
    if not isinstance(a, ast.AST):
        return []
    imports = []
    for child in ast.walk(a):
        if isinstance(child, ast.Import):
            for imported in child.names:
                if imported.name in supportedLibraries:
                    imports.append(imported.asname if imported.asname is not None else imported.name)
                else:
                    log("astTools\tgetAllImports\tUnknown library: " + imported.name, "bug")
        elif isinstance(child, ast.ImportFrom):
            module = child.module
            if module in supportedLibraries:
                for imported in child.names: # these are all functions
                    if imported.name in libraryMap[module]:
                        imports.append(imported.asname if imported.asname is not None else imported.name)
                    else:
                        log("astTools\tgetAllImports\tUnknown import from name: " + \
                                    module + "," + imported.name, "bug")
            else:
                # Relative imports such as ``from . import x`` have module None;
                # str() keeps this message from raising TypeError on concatenation.
                log("astTools\tgetAllImports\tUnknown library: " + str(module), "bug")
    return imports
optimizer.py 文件源码 项目:fatoptimizer 作者: vstinner 项目源码 文件源码 阅读 34 收藏 0 点赞 0 评论 0
def add_import(tree, name, asname):
    """Insert ``import <name> as <asname>`` into the module *tree*.

    The new statement goes after a leading module docstring and after any
    leading ``from __future__ import ...`` statements, i.e. before the
    first other statement.  If the body holds nothing else (or is empty),
    it is appended at the end.  *tree* is modified in place.
    """
    # import fat as __fat__
    new_import = ast.Import(names=[ast.alias(name=name, asname=asname)],
                            lineno=1, col_offset=1)
    # Default: end of body (empty body, or only docstring/__future__ lines).
    position = len(tree.body)
    for idx, stmt in enumerate(tree.body):
        is_docstring = (idx == 0 and isinstance(stmt, ast.Expr)
                        and isinstance(stmt.value, ast.Constant)
                        and isinstance(stmt.value.value, str))
        is_future_import = (isinstance(stmt, ast.ImportFrom)
                            and stmt.module == '__future__')
        if not (is_docstring or is_future_import):
            position = idx
            break
    tree.body.insert(position, new_import)
modulegraph.py 文件源码 项目:mac-package-build 作者: persepolisdm 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def __init__(self, name, node):
        """
        Initialize this alias node.

        Parameters
        ----------
        name : str
            Fully-qualified name of the non-existent target module to be
            created (as an alias of the existing source module).
        node : Node
            Graph node of the existing source module being aliased. A fixed
            set of its attributes is copied onto this alias (those it lacks
            are skipped).
        """
        super(AliasNode, self).__init__(name)

        #FIXME: Why only some? Why not *EVERYTHING* except "graphident", which
        #must remain equal to "name" for lookup purposes? This is, after all,
        #an alias. The idea is for the two nodes to effectively be the same.

        # Attributes mirrored from the source module onto this target alias.
        mirrored_attrs = (
            'identifier',
            'packagepath',
            '_global_attr_names',
            '_starimported_ignored_module_names',
            '_submodule_basename_to_node',
        )
        for attr in mirrored_attrs:
            if not hasattr(node, attr):
                continue
            setattr(self, attr, getattr(node, attr))
diffAsts.py 文件源码 项目:ITAP-django 作者: krivers 项目源码 文件源码 阅读 34 收藏 0 点赞 0 评论 0
def getChanges(s, t, ignoreVariables=False):
    """Diff tree *s* against tree *t* and tag every change with its start tree.

    Each change returned by ``diffAsts`` has ``start`` set to *s* itself.
    Returns the list of changes from ``diffAsts``.
    """
    found = diffAsts(s, t, ignoreVariables=ignoreVariables)
    for change in found:
        # WARNING: should maybe have a deepcopy here? It will alias s
        change.start = s
    return found
module_definitions.py 文件源码 项目:pyt 作者: python-security 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def __init__(self, local_module_definitions,
                 name, parent_module_name, path):
        """Store a module definition and compute its qualified name.

        ``self.name`` is ``name`` alone when there is no parent module, and
        otherwise ``<parent>.<name>``.  The parent may be an ``ast.alias``
        (its ``.name`` is used) or a plain string.
        """
        self.module_definitions = local_module_definitions
        self.parent_module_name = parent_module_name
        self.path = path

        if not parent_module_name:
            self.name = name
            return

        parent = (parent_module_name.name
                  if isinstance(parent_module_name, ast.alias)
                  else parent_module_name)
        self.name = parent + '.' + name
module_definitions.py 文件源码 项目:pyt 作者: python-security 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def __init__(self, import_names=None, module_name=None, is_init=False, filename=None):
        """Optionally record the import names and the module name.

        The module name should only be set for a normal ``import`` statement;
        it may be a plain string or an ``ast.alias``.  The definitions and
        classes lists and the import-alias mapping start out empty.
        """
        self.import_names = import_names
        # Either an ast.alias or a str, depending on the caller.
        self.module_name = module_name
        self.is_init = is_init
        self.filename = filename
        self.definitions = []
        self.classes = []
        self.import_alias_mapping = dict()
module_definitions.py 文件源码 项目:pyt 作者: python-security 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def __str__(self):
        """Describe the definitions (or import names), module, filename and is_init.

        The module is shown as ``NoModuleName`` when ``module_name`` is
        falsy, and by its ``.name`` when it is an ``ast.alias``.
        """
        module = self.module_name if self.module_name else 'NoModuleName'
        if isinstance(module, ast.alias):
            module = module.name

        if self.definitions:
            listed = '", "'.join([str(definition) for definition in self.definitions])
            head = 'Definitions: "' + listed + '" and module_name: ' + module
        else:
            head = ('import_names is ' + str(self.import_names) +
                    ' No Definitions, module_name: ' + str(module))
        return (head +
                ' and filename: ' + str(self.filename) +
                ' and is_init: ' + str(self.is_init) + '\n')
test_ast.py 文件源码 项目:zippy 作者: securesystemslab 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def test_bad_integer(self):
        # issue13436: Bad error message with invalid numeric values
        stmt = ast.ImportFrom(module='time',
                              names=[ast.alias(name='sleep')],
                              level=None,
                              lineno=None, col_offset=None)
        # type_ignores is passed explicitly: since Python 3.8 it is a required
        # Module field, and leaving it out makes compile() complain about the
        # missing field instead of the bad integer this test targets.
        # NOTE(review): older ast constructors are expected to accept the extra
        # keyword as a plain attribute — confirm on this interpreter.
        mod = ast.Module(body=[stmt], type_ignores=[])
        with self.assertRaises((TypeError, ValueError)) as cm:
            compile(mod, 'test', 'exec')
        if support.check_impl_detail():
            self.assertIn("invalid integer value: None", str(cm.exception))
analyzer.py 文件源码 项目:chalice 作者: aws 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def visit_Import(self, node):
        # type: (ast.Import) -> None
        """Record a Boto3ModuleType for every ``import`` of the SDK package.

        Only aliases whose ``name`` equals ``self._SDK_PACKAGE`` are recorded;
        afterwards traversal continues via ``self.generic_visit(node)``.
        """
        imported = [child for child in node.names if isinstance(child, ast.alias)]
        for alias_node in imported:
            if alias_node.name != self._SDK_PACKAGE:
                continue
            self._set_inferred_type_for_name(
                alias_node.name, Boto3ModuleType())
        self.generic_visit(node)
modulegraph.py 文件源码 项目:driveboardapp 作者: nortd 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def _ast_names(names):
    """Return plain module names from *names*, dropping ``'__main__'``.

    Items that are ``ast.alias`` nodes contribute their ``.name``; any
    other item is kept as-is.  Input order is preserved.
    """
    plain = [nm.name if isinstance(nm, ast.alias) else nm for nm in names]
    return [entry for entry in plain if entry != '__main__']
modulegraph.py 文件源码 项目:driveboardapp 作者: nortd 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def alias_module(self, src_module_name, trg_module_name):
        """
        Alias the source module to the target module with the passed names.

        Registers a lazy ``Alias(src_module_name)`` under the target name in
        ``self.lazynodes`` so that a later findNode() for the target module
        resolves this alias. Calling it again for the same source/target
        pair is accepted.

        NOTE(review): importing/adding the source module's graph node and the
        target->source reference are presumably done when findNode() resolves
        the lazy entry; this body does not do them directly — confirm.

        Parameters
        ----------
        src_module_name : str
            Fully-qualified name of the existing **source module** (i.e., the
            module being aliased).
        trg_module_name : str
            Fully-qualified name of the non-existent **target module** (i.e.,
            the alias to be created).

        Raises
        ------
        ValueError
            If the target is already in the graph as a non-alias, or as an
            alias of a different source module.
        """
        self.msg(3, 'alias_module "%s" -> "%s"' % (src_module_name, trg_module_name))
        for module_name in (src_module_name, trg_module_name):
            assert isinstance(module_name, str), '"%s" not a module name.' % str(module_name)

        existing = self.findNode(trg_module_name)
        if existing is not None:
            is_same_alias = (isinstance(existing, AliasNode) and
                             existing.identifier == src_module_name)
            if not is_same_alias:
                raise ValueError('Target module "%s" already imported as "%s".' % (trg_module_name, existing))

        # See findNode() for details.
        self.lazynodes[trg_module_name] = Alias(src_module_name)
test_ast.py 文件源码 项目:web_ctp 作者: molebot 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def test_bad_integer(self):
        # issue13436: Bad error message with invalid numeric values
        stmt = ast.ImportFrom(module='time',
                              names=[ast.alias(name='sleep')],
                              level=None,
                              lineno=None, col_offset=None)
        # type_ignores is passed explicitly: since Python 3.8 it is a required
        # Module field, and leaving it out makes compile() raise TypeError about
        # the missing field instead of the ValueError this test targets.
        # NOTE(review): older ast constructors are expected to accept the extra
        # keyword as a plain attribute — confirm on this interpreter.
        mod = ast.Module(body=[stmt], type_ignores=[])
        with self.assertRaises(ValueError) as cm:
            compile(mod, 'test', 'exec')
        self.assertIn("invalid integer value: None", str(cm.exception))
test_ast.py 文件源码 项目:web_ctp 作者: molebot 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def test_importfrom(self):
        # self.stmt presumably checks that validating the statement fails with
        # a message containing the given text — helper is defined elsewhere.
        level_too_low = ast.ImportFrom(None, [ast.alias("x", None)], -42)
        self.stmt(level_too_low, "level less than -1")
        no_names = ast.ImportFrom(None, [], 0)
        self.stmt(no_names, "empty names on ImportFrom")
flake8_mypy.py 文件源码 项目:flake8-mypy 作者: ambv 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def visit_Import(self, node: ast.Import) -> None:
        """Mark the module for type checking if it imports ``typing``.

        Matches ``import typing`` and ``import typing.<submodule>`` (with or
        without ``as``); stops at the first match.  Only ``ast.alias``
        entries are considered.
        """
        for name in node.names:
            # The parentheses matter: ``and`` binds tighter than ``or``, so the
            # unparenthesized form let the startswith() test bypass the
            # isinstance() guard.
            if isinstance(name, ast.alias) and (
                name.name == 'typing' or
                name.name.startswith('typing.')
            ):
                self.should_type_check = True
                break
test_ast.py 文件源码 项目:ouroboros 作者: pybee 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def test_bad_integer(self):
        # issue13436: Bad error message with invalid numeric values
        stmt = ast.ImportFrom(module='time',
                              names=[ast.alias(name='sleep')],
                              level=None,
                              lineno=None, col_offset=None)
        # type_ignores is passed explicitly: since Python 3.8 it is a required
        # Module field, and leaving it out makes compile() raise TypeError about
        # the missing field instead of the ValueError this test targets.
        # NOTE(review): older ast constructors are expected to accept the extra
        # keyword as a plain attribute — confirm on this interpreter.
        mod = ast.Module(body=[stmt], type_ignores=[])
        with self.assertRaises(ValueError) as cm:
            compile(mod, 'test', 'exec')
        self.assertIn("invalid integer value: None", str(cm.exception))
test_ast.py 文件源码 项目:ouroboros 作者: pybee 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def test_importfrom(self):
        # self.stmt presumably checks that validating the statement fails with
        # a message containing the given text — helper is defined elsewhere.
        level_too_low = ast.ImportFrom(None, [ast.alias("x", None)], -42)
        self.stmt(level_too_low, "level less than -1")
        no_names = ast.ImportFrom(None, [], 0)
        self.stmt(no_names, "empty names on ImportFrom")
pre_analysis.py 文件源码 项目:Typpete 作者: caterinaurban 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def get_all_used_names(self):
        """Get all used variable names and user-defined class names.

        Scans ``self.all_nodes`` and returns, in this order: ``Name.id``
        values, ``ClassDef.name`` values, ``Attribute.attr`` values and
        ``alias.name`` values (each group in node order).
        """
        # (node type, attribute that carries the name), in output order.
        name_sources = ((ast.Name, 'id'),
                        (ast.ClassDef, 'name'),
                        (ast.Attribute, 'attr'),
                        (ast.alias, 'name'))
        names = []
        for node_type, field in name_sources:
            names.extend(getattr(node, field)
                         for node in self.all_nodes
                         if isinstance(node, node_type))
        return names
test_parse.py 文件源码 项目:pydead 作者: srgypetrov 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def test_visit_import(self, pyfile):
        """``import Foo as Bar`` is recorded in ast_imported as Bar -> Foo."""
        import_stmt = ast.Import(names=[ast.alias(name='Foo', asname='Bar')])
        pyfile.visit_import(import_stmt)
        expected = {'Bar': 'Foo'}
        assert pyfile.ast_imported == expected
test_parse.py 文件源码 项目:pydead 作者: srgypetrov 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def test_visit_importfrom_invalid(self, pyfile):
        """A ``from apps.app.module import *`` node makes visit_importfrom exit."""
        star = ast.alias(name='*', asname='')
        node = ast.ImportFrom(names=[star], module='apps.app.module', level=0)
        with pytest.raises(SystemExit):
            pyfile.visit_importfrom(node)
test_ast.py 文件源码 项目:kbe_server 作者: xiaohaoppy 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def test_bad_integer(self):
        # issue13436: Bad error message with invalid numeric values
        stmt = ast.ImportFrom(module='time',
                              names=[ast.alias(name='sleep')],
                              level=None,
                              lineno=None, col_offset=None)
        # type_ignores is passed explicitly: since Python 3.8 it is a required
        # Module field, and leaving it out makes compile() raise TypeError about
        # the missing field instead of the ValueError this test targets.
        # NOTE(review): older ast constructors are expected to accept the extra
        # keyword as a plain attribute — confirm on this interpreter.
        mod = ast.Module(body=[stmt], type_ignores=[])
        with self.assertRaises(ValueError) as cm:
            compile(mod, 'test', 'exec')
        self.assertIn("invalid integer value: None", str(cm.exception))
test_ast.py 文件源码 项目:kbe_server 作者: xiaohaoppy 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def test_importfrom(self):
        # self.stmt presumably checks that validating the statement fails with
        # a message containing the given text — helper is defined elsewhere.
        level_too_low = ast.ImportFrom(None, [ast.alias("x", None)], -42)
        self.stmt(level_too_low, "level less than -1")
        no_names = ast.ImportFrom(None, [], 0)
        self.stmt(no_names, "empty names on ImportFrom")
modulegraph.py 文件源码 项目:mac-package-build 作者: persepolisdm 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def _ast_names(names):
    """Return plain module names from *names*, dropping ``'__main__'``.

    Items that are ``ast.alias`` nodes contribute their ``.name``; any
    other item is kept as-is.  Input order is preserved.
    """
    plain = [nm.name if isinstance(nm, ast.alias) else nm for nm in names]
    return [entry for entry in plain if entry != '__main__']
modulegraph.py 文件源码 项目:mac-package-build 作者: persepolisdm 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def alias_module(self, src_module_name, trg_module_name):
        """
        Alias the source module to the target module with the passed names.

        Registers a lazy ``Alias(src_module_name)`` under the target name in
        ``self.lazynodes`` so that a later findNode() for the target module
        resolves this alias. Calling it again for the same source/target
        pair is accepted.

        NOTE(review): importing/adding the source module's graph node and the
        target->source reference are presumably done when findNode() resolves
        the lazy entry; this body does not do them directly — confirm.

        Parameters
        ----------
        src_module_name : str
            Fully-qualified name of the existing **source module** (i.e., the
            module being aliased).
        trg_module_name : str
            Fully-qualified name of the non-existent **target module** (i.e.,
            the alias to be created).

        Raises
        ------
        ValueError
            If the target is already in the graph as a non-alias, or as an
            alias of a different source module.
        """
        self.msg(3, 'alias_module "%s" -> "%s"' % (src_module_name, trg_module_name))
        for module_name in (src_module_name, trg_module_name):
            assert isinstance(module_name, str), '"%s" not a module name.' % str(module_name)

        existing = self.findNode(trg_module_name)
        if existing is not None:
            is_same_alias = (isinstance(existing, AliasNode) and
                             existing.identifier == src_module_name)
            if not is_same_alias:
                raise ValueError('Target module "%s" already imported as "%s".' % (trg_module_name, existing))

        # See findNode() for details.
        self.lazynodes[trg_module_name] = Alias(src_module_name)
imports.py 文件源码 项目:pyomo 作者: Pyomo 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def check(self, runner, script, info):
        """Set ``self.pyomoImported`` when *info* imports pyomo.core/pyomo.environ.

        Recognizes ``import pyomo.core`` / ``import pyomo.environ`` (with or
        without ``as``) and ``from pyomo.core import ...`` /
        ``from pyomo.environ import ...``.  The flag is only ever set to
        True here; it is never cleared.
        """
        pyomo_modules = ('pyomo.core', 'pyomo.environ')
        if isinstance(info, ast.Import):
            for name in info.names:
                if isinstance(name, ast.alias) and name.name in pyomo_modules:
                    self.pyomoImported = True

        if isinstance(info, ast.ImportFrom) and info.module in pyomo_modules:
            self.pyomoImported = True
rewrite.py 文件源码 项目:hostapd-mana 作者: adde88 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def run(self, mod):
        """Find all assert statements in *mod* and rewrite them.

        First inserts ``import <builtins module> as @py_builtins`` and
        ``import _pytest.assertion.rewrite as @pytest_ar`` after the module
        docstring and any ``from __future__`` imports.  Then walks the
        statement tree (not descending into expressions) and replaces each
        ``ast.Assert`` with the statements returned by ``self.visit``.
        Returns without changes if the body is empty or the docstring
        contains ``PYTEST_DONT_REWRITE``.
        """
        if not mod.body:
            # Nothing to do.
            return
        # Insert some special imports at the top of the module but after any
        # docstrings and __future__ imports.
        aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
                   ast.alias("_pytest.assertion.rewrite", "@pytest_ar")]
        expect_docstring = True
        pos = 0
        # AST line numbers are 1-based.
        lineno = 1
        for item in mod.body:
            if (expect_docstring and isinstance(item, ast.Expr) and
                    isinstance(item.value, ast.Str)):
                doc = item.value.s
                if "PYTEST_DONT_REWRITE" in doc:
                    # The module has disabled assertion rewriting.
                    return
                expect_docstring = False
            elif (not isinstance(item, ast.ImportFrom) or item.level > 0 or
                  item.module != "__future__"):
                lineno = item.lineno
                break
            pos += 1
        else:
            # Body holds only a docstring and/or __future__ imports: use the
            # last statement's line instead of the old docstring *character*
            # count, which produced meaningless (even zero/negative) lines.
            lineno = item.lineno
        imports = [ast.Import([alias], lineno=lineno, col_offset=0)
                   for alias in aliases]
        mod.body[pos:pos] = imports
        # Collect asserts.
        nodes = [mod]
        while nodes:
            node = nodes.pop()
            for name, field in ast.iter_fields(node):
                if isinstance(field, list):
                    new = []
                    for child in field:
                        if isinstance(child, ast.Assert):
                            # Transform assert.
                            new.extend(self.visit(child))
                        else:
                            new.append(child)
                            if isinstance(child, ast.AST):
                                nodes.append(child)
                    setattr(node, name, new)
                elif (isinstance(field, ast.AST) and
                      # Don't recurse into expressions as they can't contain
                      # asserts.
                      not isinstance(field, ast.expr)):
                    nodes.append(field)
rewrite.py 文件源码 项目:sslstrip-hsts-openwrt 作者: adde88 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def run(self, mod):
        """Find all assert statements in *mod* and rewrite them.

        First inserts ``import <builtins module> as @py_builtins`` and
        ``import _pytest.assertion.rewrite as @pytest_ar`` after the module
        docstring and any ``from __future__`` imports.  Then walks the
        statement tree (not descending into expressions) and replaces each
        ``ast.Assert`` with the statements returned by ``self.visit``.
        Returns without changes if the body is empty or the docstring
        contains ``PYTEST_DONT_REWRITE``.
        """
        if not mod.body:
            # Nothing to do.
            return
        # Insert some special imports at the top of the module but after any
        # docstrings and __future__ imports.
        aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
                   ast.alias("_pytest.assertion.rewrite", "@pytest_ar")]
        expect_docstring = True
        pos = 0
        # AST line numbers are 1-based.
        lineno = 1
        for item in mod.body:
            if (expect_docstring and isinstance(item, ast.Expr) and
                    isinstance(item.value, ast.Str)):
                doc = item.value.s
                if "PYTEST_DONT_REWRITE" in doc:
                    # The module has disabled assertion rewriting.
                    return
                expect_docstring = False
            elif (not isinstance(item, ast.ImportFrom) or item.level > 0 or
                  item.module != "__future__"):
                lineno = item.lineno
                break
            pos += 1
        else:
            # Body holds only a docstring and/or __future__ imports: use the
            # last statement's line instead of the old docstring *character*
            # count, which produced meaningless (even zero/negative) lines.
            lineno = item.lineno
        imports = [ast.Import([alias], lineno=lineno, col_offset=0)
                   for alias in aliases]
        mod.body[pos:pos] = imports
        # Collect asserts.
        nodes = [mod]
        while nodes:
            node = nodes.pop()
            for name, field in ast.iter_fields(node):
                if isinstance(field, list):
                    new = []
                    for child in field:
                        if isinstance(child, ast.Assert):
                            # Transform assert.
                            new.extend(self.visit(child))
                        else:
                            new.append(child)
                            if isinstance(child, ast.AST):
                                nodes.append(child)
                    setattr(node, name, new)
                elif (isinstance(field, ast.AST) and
                      # Don't recurse into expressions as they can't contain
                      # asserts.
                      not isinstance(field, ast.expr)):
                    nodes.append(field)
rewrite.py 文件源码 项目:godot-python 作者: touilleMan 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def run(self, mod):
        """Find all assert statements in *mod* and rewrite them.

        First inserts ``import <builtins module> as @py_builtins`` and
        ``import _pytest.assertion.rewrite as @pytest_ar`` after the module
        docstring and any ``from __future__`` imports.  Then walks the
        statement tree (not descending into expressions) and replaces each
        ``ast.Assert`` with the statements returned by ``self.visit``.
        Returns without changes if the body is empty or the docstring
        contains ``PYTEST_DONT_REWRITE``.
        """
        if not mod.body:
            # Nothing to do.
            return
        # Insert some special imports at the top of the module but after any
        # docstrings and __future__ imports.
        aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
                   ast.alias("_pytest.assertion.rewrite", "@pytest_ar")]
        expect_docstring = True
        pos = 0
        # AST line numbers are 1-based.
        lineno = 1
        for item in mod.body:
            if (expect_docstring and isinstance(item, ast.Expr) and
                    isinstance(item.value, ast.Str)):
                doc = item.value.s
                if "PYTEST_DONT_REWRITE" in doc:
                    # The module has disabled assertion rewriting.
                    return
                expect_docstring = False
            elif (not isinstance(item, ast.ImportFrom) or item.level > 0 or
                  item.module != "__future__"):
                lineno = item.lineno
                break
            pos += 1
        else:
            # Body holds only a docstring and/or __future__ imports: use the
            # last statement's line instead of the old docstring *character*
            # count, which produced meaningless (even zero/negative) lines.
            lineno = item.lineno
        imports = [ast.Import([alias], lineno=lineno, col_offset=0)
                   for alias in aliases]
        mod.body[pos:pos] = imports
        # Collect asserts.
        nodes = [mod]
        while nodes:
            node = nodes.pop()
            for name, field in ast.iter_fields(node):
                if isinstance(field, list):
                    new = []
                    for child in field:
                        if isinstance(child, ast.Assert):
                            # Transform assert.
                            new.extend(self.visit(child))
                        else:
                            new.append(child)
                            if isinstance(child, ast.AST):
                                nodes.append(child)
                    setattr(node, name, new)
                elif (isinstance(field, ast.AST) and
                      # Don't recurse into expressions as they can't contain
                      # asserts.
                      not isinstance(field, ast.expr)):
                    nodes.append(field)
rewrite.py 文件源码 项目:godot-python 作者: touilleMan 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def run(self, mod):
        """Find all assert statements in *mod* and rewrite them.

        First inserts ``import <builtins module> as @py_builtins`` and
        ``import _pytest.assertion.rewrite as @pytest_ar`` after the module
        docstring and any ``from __future__`` imports.  Then walks the
        statement tree (not descending into expressions) and replaces each
        ``ast.Assert`` with the statements returned by ``self.visit``.
        Returns without changes if the body is empty or the docstring
        contains ``PYTEST_DONT_REWRITE``.
        """
        if not mod.body:
            # Nothing to do.
            return
        # Insert some special imports at the top of the module but after any
        # docstrings and __future__ imports.
        aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
                   ast.alias("_pytest.assertion.rewrite", "@pytest_ar")]
        expect_docstring = True
        pos = 0
        # AST line numbers are 1-based.
        lineno = 1
        for item in mod.body:
            if (expect_docstring and isinstance(item, ast.Expr) and
                    isinstance(item.value, ast.Str)):
                doc = item.value.s
                if "PYTEST_DONT_REWRITE" in doc:
                    # The module has disabled assertion rewriting.
                    return
                expect_docstring = False
            elif (not isinstance(item, ast.ImportFrom) or item.level > 0 or
                  item.module != "__future__"):
                lineno = item.lineno
                break
            pos += 1
        else:
            # Body holds only a docstring and/or __future__ imports: use the
            # last statement's line instead of the old docstring *character*
            # count, which produced meaningless (even zero/negative) lines.
            lineno = item.lineno
        imports = [ast.Import([alias], lineno=lineno, col_offset=0)
                   for alias in aliases]
        mod.body[pos:pos] = imports
        # Collect asserts.
        nodes = [mod]
        while nodes:
            node = nodes.pop()
            for name, field in ast.iter_fields(node):
                if isinstance(field, list):
                    new = []
                    for child in field:
                        if isinstance(child, ast.Assert):
                            # Transform assert.
                            new.extend(self.visit(child))
                        else:
                            new.append(child)
                            if isinstance(child, ast.AST):
                                nodes.append(child)
                    setattr(node, name, new)
                elif (isinstance(field, ast.AST) and
                      # Don't recurse into expressions as they can't contain
                      # asserts.
                      not isinstance(field, ast.expr)):
                    nodes.append(field)
rewrite.py 文件源码 项目:GSM-scanner 作者: yosriayed 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def run(self, mod):
        """Find all assert statements in *mod* and rewrite them.

        First inserts ``import <builtins module> as @py_builtins`` and
        ``import _pytest.assertion.rewrite as @pytest_ar`` after the module
        docstring and any ``from __future__`` imports.  Then walks the
        statement tree (not descending into expressions) and replaces each
        ``ast.Assert`` with the statements returned by ``self.visit``.
        Returns without changes if the body is empty or the docstring
        contains ``PYTEST_DONT_REWRITE``.
        """
        if not mod.body:
            # Nothing to do.
            return
        # Insert some special imports at the top of the module but after any
        # docstrings and __future__ imports.
        aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
                   ast.alias("_pytest.assertion.rewrite", "@pytest_ar")]
        expect_docstring = True
        pos = 0
        # AST line numbers are 1-based.
        lineno = 1
        for item in mod.body:
            if (expect_docstring and isinstance(item, ast.Expr) and
                    isinstance(item.value, ast.Str)):
                doc = item.value.s
                if "PYTEST_DONT_REWRITE" in doc:
                    # The module has disabled assertion rewriting.
                    return
                expect_docstring = False
            elif (not isinstance(item, ast.ImportFrom) or item.level > 0 or
                  item.module != "__future__"):
                lineno = item.lineno
                break
            pos += 1
        else:
            # Body holds only a docstring and/or __future__ imports: use the
            # last statement's line instead of the old docstring *character*
            # count, which produced meaningless (even zero/negative) lines.
            lineno = item.lineno
        imports = [ast.Import([alias], lineno=lineno, col_offset=0)
                   for alias in aliases]
        mod.body[pos:pos] = imports
        # Collect asserts.
        nodes = [mod]
        while nodes:
            node = nodes.pop()
            for name, field in ast.iter_fields(node):
                if isinstance(field, list):
                    new = []
                    for child in field:
                        if isinstance(child, ast.Assert):
                            # Transform assert.
                            new.extend(self.visit(child))
                        else:
                            new.append(child)
                            if isinstance(child, ast.AST):
                                nodes.append(child)
                    setattr(node, name, new)
                elif (isinstance(field, ast.AST) and
                      # Don't recurse into expressions as they can't contain
                      # asserts.
                      not isinstance(field, ast.expr)):
                    nodes.append(field)


问题


面经


文章

微信
公众号

扫码关注公众号