def basicTypeSpecialFunction(cv):
    """Promote a change on a metadata-less leaf value to a change on its AST node.

    Numbers, strings, names, arguments, name constants, bytes and import
    aliases are raw values living inside an AST node, so a change whose path
    starts at one of those leaf fields is rewritten to replace the whole
    enclosing node instead; this is what lets the special functions work.

    When the first path step matches a leaf field, the vector is updated in
    place and returned:
      * ``oldSubtree`` becomes a deep copy of what ``traverseTree(cv.start)``
        returns on a copy of the vector,
      * ``newSubtree`` is wrapped in the matching AST node type (reusing
        ``ctx`` / ``annotation`` / ``asname`` from the old node where needed),
      * the leading step is removed from ``cv.path``.
    Swap and move vectors, and vectors whose first step is not a leaf field,
    are returned untouched.
    """
    if isinstance(cv, (SwapVector, MoveVector)):
        return cv
    # (first path step, builder taking the new raw value and the old node)
    leaf_wrappers = (
        (('n', 'Number'), lambda value, old: ast.Num(value)),
        (('s', 'String'), lambda value, old: ast.Str(value)),
        (('id', 'Name'), lambda value, old: ast.Name(value, old.ctx)),
        (('arg', 'Argument'), lambda value, old: ast.arg(value, old.annotation)),
        (('value', 'Name Constant'), lambda value, old: ast.NameConstant(value)),
        (('s', 'Bytes'), lambda value, old: ast.Bytes(value)),
        (('name', 'Alias'), lambda value, old: ast.alias(value, old.asname)),
    )
    first_step = cv.path[0]
    for step, wrap in leaf_wrappers:
        if first_step == step:
            break
    else:
        return cv
    snapshot = cv.deepcopy()
    cv.oldSubtree = deepcopy(snapshot.traverseTree(cv.start))
    cv.newSubtree = wrap(cv.newSubtree, cv.oldSubtree)
    cv.path = cv.path[1:]
    return cv
# Example source code using the Python class alias() (ast.alias)
def getAllImports(a):
    """Gather all imported module names.

    Walks every node of the AST *a* and collects the local names under which
    supported libraries are imported:

    * ``import x [as y]`` contributes ``y`` (or ``x`` when there is no
      ``as``) if ``x`` is in ``supportedLibraries``;
    * ``from m import f [as g]`` contributes ``g`` (or ``f``) if ``m`` is in
      ``supportedLibraries`` and ``f`` is listed in ``libraryMap[m]``.

    Anything else is reported through ``log(..., "bug")`` and skipped.
    Returns an empty list when *a* is not an ``ast.AST``.
    """
    if not isinstance(a, ast.AST):
        return []
    imports = []
    for child in ast.walk(a):
        if isinstance(child, ast.Import):
            for alias in child.names:
                if alias.name in supportedLibraries:
                    imports.append(alias.asname if alias.asname is not None else alias.name)
                else:
                    log("astTools\tgetAllImports\tUnknown library: " + alias.name, "bug")
        elif isinstance(child, ast.ImportFrom):
            if child.module in supportedLibraries:
                for alias in child.names:  # these are all functions
                    if alias.name in libraryMap[child.module]:
                        imports.append(alias.asname if alias.asname is not None else alias.name)
                    else:
                        log("astTools\tgetAllImports\tUnknown import from name: " +
                            str(child.module) + "," + alias.name, "bug")
            else:
                # child.module is None for relative imports such as
                # ``from . import x``; str() keeps the log call from raising
                # TypeError on the string concatenation.
                log("astTools\tgetAllImports\tUnknown library: " + str(child.module), "bug")
    return imports
def add_import(tree, name, asname):
    """Insert ``import name as asname`` into the module AST *tree* in place.

    The new statement goes right after a leading docstring (only the very
    first statement counts as one) and any ``from __future__`` imports; if
    the body holds nothing else, it is appended at the end.
    """
    # import fat as __fat__
    new_import = ast.Import(names=[ast.alias(name=name, asname=asname)],
                            lineno=1, col_offset=1)
    position = len(tree.body)
    for idx, stmt in enumerate(tree.body):
        is_docstring = (idx == 0
                        and isinstance(stmt, ast.Expr)
                        and isinstance(stmt.value, ast.Constant)
                        and isinstance(stmt.value.value, str))
        is_future = isinstance(stmt, ast.ImportFrom) and stmt.module == '__future__'
        if not (is_docstring or is_future):
            position = idx
            break
    tree.body.insert(position, new_import)
def __init__(self, name, node):
    """
    Initialize this alias.

    Parameters
    ----------
    name : str
        Fully-qualified name of the non-existent target module to be
        created (as an alias of the existing source module).
    node : Node
        Graph node of the existing source module being aliased.

    Only a fixed subset of the source node's attributes is mirrored onto
    this alias, and only the ones the source node actually has.
    """
    super(AliasNode, self).__init__(name)
    # FIXME: an alias is meant to behave exactly like the node it aliases, so
    # arguably everything except "graphident" (which must stay equal to
    # *name* for lookups) should be copied -- not just this subset.
    mirrored_attrs = (
        'identifier',
        'packagepath',
        '_global_attr_names',
        '_starimported_ignored_module_names',
        '_submodule_basename_to_node',
    )
    for attr in mirrored_attrs:
        if hasattr(node, attr):
            setattr(self, attr, getattr(node, attr))
def getChanges(s, t, ignoreVariables=False):
    """Return the changes ``diffAsts`` finds between *s* and *t*, anchored at *s*.

    Each change's ``start`` is set to *s* itself, not a copy, so all changes
    share (alias) the caller's tree.
    """
    diffs = diffAsts(s, t, ignoreVariables=ignoreVariables)
    for diff in diffs:
        # WARNING: no deepcopy here -- mutating s later also changes diff.start.
        diff.start = s
    return diffs
def __init__(self, local_module_definitions,
             name, parent_module_name, path):
    """Store a module's definitions and derive its fully-qualified name.

    When *parent_module_name* is truthy the name becomes ``<parent>.<name>``;
    the parent may be a plain string or an ``ast.alias`` (whose ``.name`` is
    then used). Otherwise *name* is used unchanged.
    """
    self.module_definitions = local_module_definitions
    self.parent_module_name = parent_module_name
    self.path = path
    if not parent_module_name:
        self.name = name
    else:
        parent = (parent_module_name.name
                  if isinstance(parent_module_name, ast.alias)
                  else parent_module_name)
        self.name = parent + '.' + name
def __init__(self, import_names=None, module_name=None, is_init=False, filename=None):
    """Optionally set import names and module name.

    Module name should only be set when it is a normal import statement.
    The definition list, class list and import-alias mapping always start
    out empty.
    """
    self.import_names, self.module_name = import_names, module_name
    # module_name may be either an ast.alias or a plain string.
    self.is_init, self.filename = is_init, filename
    self.definitions, self.classes = [], []
    self.import_alias_mapping = {}
def __str__(self):
    """Render a one-line summary of this module definition.

    ``module_name`` may be an ``ast.alias`` (its ``.name`` is shown), a
    plain string, or empty (shown as ``NoModuleName``).
    """
    module = self.module_name if self.module_name else 'NoModuleName'
    label = module.name if isinstance(module, ast.alias) else module
    tail = (' and filename: ' + str(self.filename) +
            ' and is_init: ' + str(self.is_init) + '\n')
    if not self.definitions:
        return ('import_names is ' + str(self.import_names) +
                ' No Definitions, module_name: ' + str(label) + tail)
    joined = '", "'.join([str(definition) for definition in self.definitions])
    return 'Definitions: "' + joined + '" and module_name: ' + label + tail
def test_bad_integer(self):
    # issue13436: compiling an AST whose lineno/col_offset/level are None
    # must fail with a clear error (the message is only checked where
    # support.check_impl_detail() says so).
    import_from = ast.ImportFrom(module='time',
                                 names=[ast.alias(name='sleep')],
                                 level=None,
                                 lineno=None, col_offset=None)
    module = ast.Module([import_from])
    with self.assertRaises((TypeError, ValueError)) as ctx:
        compile(module, 'test', 'exec')
    if support.check_impl_detail():
        self.assertIn("invalid integer value: None", str(ctx.exception))
def visit_Import(self, node):
    # type: (ast.Import) -> None
    """Record a Boto3ModuleType for each imported name equal to ``self._SDK_PACKAGE``.

    Only the imported module name is compared (not any ``as`` name). The
    node's children are visited afterwards.
    """
    for imported in node.names:
        if isinstance(imported, ast.alias) and imported.name == self._SDK_PACKAGE:
            self._set_inferred_type_for_name(imported.name, Boto3ModuleType())
    self.generic_visit(node)
def _ast_names(names):
    """Return the plain names in *names*, skipping ``'__main__'``.

    Items may be ``ast.alias`` nodes (their ``.name`` is used) or values
    that are kept as-is.
    """
    plain = (nm.name if isinstance(nm, ast.alias) else nm for nm in names)
    return [item for item in plain if item != '__main__']
def alias_module(self, src_module_name, trg_module_name):
    """
    Alias the source module to the target module with the passed names.

    Nothing is resolved immediately: an ``Alias`` of the source module is
    stored in ``self.lazynodes`` under the target name. According to the
    original design, the next ``findNode()`` call for the target name
    resolves it, importing and adding a graph node for the source module if
    needed and adding a target-to-source reference.

    Parameters
    ----------
    src_module_name : str
        Fully-qualified name of the existing **source module** (i.e., the
        module being aliased).
    trg_module_name : str
        Fully-qualified name of the non-existent **target module** (i.e.,
        the alias to be created).

    Raises
    ------
    ValueError
        If the target name is already in the graph as a non-alias, or as an
        alias of a different source module.
    """
    self.msg(3, 'alias_module "%s" -> "%s"' % (src_module_name, trg_module_name))
    for module_name in (src_module_name, trg_module_name):
        assert isinstance(module_name, str), '"%s" not a module name.' % str(module_name)
    existing = self.findNode(trg_module_name)
    # Re-aliasing the target to the very same source module is tolerated.
    same_alias = (isinstance(existing, AliasNode) and
                  existing.identifier == src_module_name)
    if existing is not None and not same_alias:
        raise ValueError('Target module "%s" already imported as "%s".' % (trg_module_name, existing))
    # See findNode() for details.
    self.lazynodes[trg_module_name] = Alias(src_module_name)
def test_bad_integer(self):
# issue13436: Bad error message with invalid numeric values
body = [ast.ImportFrom(module='time',
names=[ast.alias(name='sleep')],
level=None,
lineno=None, col_offset=None)]
mod = ast.Module(body)
with self.assertRaises(ValueError) as cm:
compile(mod, 'test', 'exec')
self.assertIn("invalid integer value: None", str(cm.exception))
def test_importfrom(self):
imp = ast.ImportFrom(None, [ast.alias("x", None)], -42)
self.stmt(imp, "level less than -1")
self.stmt(ast.ImportFrom(None, [], 0), "empty names on ImportFrom")
def visit_Import(self, node: ast.Import) -> None:
    """Set ``self.should_type_check`` if *node* imports ``typing`` or a submodule of it.

    Stops at the first matching name. Items of ``node.names`` that are not
    ``ast.alias`` instances are skipped.
    """
    for name in node.names:
        # Guard first: the original ``A and B or C`` parsed as
        # ``(A and B) or C``, so ``.name`` was read even on non-alias items.
        if not isinstance(name, ast.alias):
            continue
        if name.name == 'typing' or name.name.startswith('typing.'):
            self.should_type_check = True
            break
def test_bad_integer(self):
# issue13436: Bad error message with invalid numeric values
body = [ast.ImportFrom(module='time',
names=[ast.alias(name='sleep')],
level=None,
lineno=None, col_offset=None)]
mod = ast.Module(body)
with self.assertRaises(ValueError) as cm:
compile(mod, 'test', 'exec')
self.assertIn("invalid integer value: None", str(cm.exception))
def test_importfrom(self):
imp = ast.ImportFrom(None, [ast.alias("x", None)], -42)
self.stmt(imp, "level less than -1")
self.stmt(ast.ImportFrom(None, [], 0), "empty names on ImportFrom")
def get_all_used_names(self):
    """Get all used variable names and user-defined class names.

    Names come from ``self.all_nodes`` and are grouped by node kind, in this
    order: ``Name`` ids, ``ClassDef`` names, ``Attribute`` attribute names,
    then ``alias`` (import) names. Duplicates are kept.
    """
    name_sources = (
        (ast.Name, 'id'),
        (ast.ClassDef, 'name'),
        (ast.Attribute, 'attr'),
        (ast.alias, 'name'),
    )
    return [
        getattr(node, field)
        for node_type, field in name_sources
        for node in self.all_nodes
        if isinstance(node, node_type)
    ]
def test_visit_import(self, pyfile):
    # ``import Foo as Bar`` must be recorded as {local name: imported name}.
    import_stmt = ast.Import(names=[ast.alias(name='Foo', asname='Bar')])
    pyfile.visit_import(import_stmt)
    expected = {'Bar': 'Foo'}
    assert pyfile.ast_imported == expected
def test_visit_importfrom_invalid(self, pyfile):
    # A ``from apps.app.module import *`` node must make visit_importfrom exit.
    star = ast.alias(name='*', asname='')
    import_from = ast.ImportFrom(names=[star], module='apps.app.module', level=0)
    with pytest.raises(SystemExit):
        pyfile.visit_importfrom(import_from)
def test_bad_integer(self):
# issue13436: Bad error message with invalid numeric values
body = [ast.ImportFrom(module='time',
names=[ast.alias(name='sleep')],
level=None,
lineno=None, col_offset=None)]
mod = ast.Module(body)
with self.assertRaises(ValueError) as cm:
compile(mod, 'test', 'exec')
self.assertIn("invalid integer value: None", str(cm.exception))
def test_importfrom(self):
imp = ast.ImportFrom(None, [ast.alias("x", None)], -42)
self.stmt(imp, "level less than -1")
self.stmt(ast.ImportFrom(None, [], 0), "empty names on ImportFrom")
def _ast_names(names):
    """Return the plain names in *names*, skipping ``'__main__'``.

    Items may be ``ast.alias`` nodes (their ``.name`` is used) or values
    that are kept as-is.
    """
    plain = (nm.name if isinstance(nm, ast.alias) else nm for nm in names)
    return [item for item in plain if item != '__main__']
def alias_module(self, src_module_name, trg_module_name):
    """
    Alias the source module to the target module with the passed names.

    Nothing is resolved immediately: an ``Alias`` of the source module is
    stored in ``self.lazynodes`` under the target name. According to the
    original design, the next ``findNode()`` call for the target name
    resolves it, importing and adding a graph node for the source module if
    needed and adding a target-to-source reference.

    Parameters
    ----------
    src_module_name : str
        Fully-qualified name of the existing **source module** (i.e., the
        module being aliased).
    trg_module_name : str
        Fully-qualified name of the non-existent **target module** (i.e.,
        the alias to be created).

    Raises
    ------
    ValueError
        If the target name is already in the graph as a non-alias, or as an
        alias of a different source module.
    """
    self.msg(3, 'alias_module "%s" -> "%s"' % (src_module_name, trg_module_name))
    for module_name in (src_module_name, trg_module_name):
        assert isinstance(module_name, str), '"%s" not a module name.' % str(module_name)
    existing = self.findNode(trg_module_name)
    # Re-aliasing the target to the very same source module is tolerated.
    same_alias = (isinstance(existing, AliasNode) and
                  existing.identifier == src_module_name)
    if existing is not None and not same_alias:
        raise ValueError('Target module "%s" already imported as "%s".' % (trg_module_name, existing))
    # See findNode() for details.
    self.lazynodes[trg_module_name] = Alias(src_module_name)
def check(self, runner, script, info):
    """Set ``self.pyomoImported`` when *info* imports ``pyomo.core`` or ``pyomo.environ``.

    Handles ``import ...`` (``ast.Import``) and ``from ... import ...``
    (``ast.ImportFrom``) nodes; any other node is ignored, as are *runner*
    and *script*. The flag is only ever set to True here, never reset.
    """
    pyomo_modules = ('pyomo.core', 'pyomo.environ')
    if isinstance(info, ast.Import):
        imported = [alias.name for alias in info.names if isinstance(alias, ast.alias)]
        if any(module in pyomo_modules for module in imported):
            self.pyomoImported = True
    elif isinstance(info, ast.ImportFrom) and info.module in pyomo_modules:
        self.pyomoImported = True
def run(self, mod):
    """Find all assert statements in *mod* and rewrite them.

    *mod* is modified in place:

    * two helper imports -- the builtins module as ``@py_builtins`` and
      ``_pytest.assertion.rewrite`` as ``@pytest_ar`` -- are inserted after
      the module docstring and any ``from __future__`` imports;
    * every ``ast.Assert`` found in a statement list is replaced by the
      statements returned from ``self.visit``.

    Nothing happens when the body is empty or the docstring contains
    ``PYTEST_DONT_REWRITE``.
    """
    if not mod.body:
        # Nothing to do.
        return
    # Insert some special imports at the top of the module but after any
    # docstrings and __future__ imports.
    aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
               ast.alias("_pytest.assertion.rewrite", "@pytest_ar")]
    expect_docstring = True
    pos = 0
    for item in mod.body:
        if (expect_docstring and isinstance(item, ast.Expr) and
                isinstance(item.value, ast.Str)):
            doc = item.value.s
            if "PYTEST_DONT_REWRITE" in doc:
                # The module has disabled assertion rewriting.
                return
            expect_docstring = False
        elif (not isinstance(item, ast.ImportFrom) or item.level > 0 or
              item.module != "__future__"):
            lineno = item.lineno
            break
        pos += 1
    else:
        # The body is only a docstring and/or __future__ imports: reuse the
        # last statement's line number. The former ``len(doc) - 1`` counted
        # characters rather than lines and could even go negative.
        lineno = item.lineno
    imports = [ast.Import([alias], lineno=lineno, col_offset=0)
               for alias in aliases]
    mod.body[pos:pos] = imports
    # Collect asserts.
    nodes = [mod]
    while nodes:
        node = nodes.pop()
        for name, field in ast.iter_fields(node):
            if isinstance(field, list):
                new = []
                for child in field:
                    if isinstance(child, ast.Assert):
                        # Transform assert.
                        new.extend(self.visit(child))
                    else:
                        new.append(child)
                        if isinstance(child, ast.AST):
                            nodes.append(child)
                setattr(node, name, new)
            elif (isinstance(field, ast.AST) and
                  # Don't recurse into expressions as they can't contain
                  # asserts.
                  not isinstance(field, ast.expr)):
                nodes.append(field)
def run(self, mod):
    """Find all assert statements in *mod* and rewrite them.

    *mod* is modified in place:

    * two helper imports -- the builtins module as ``@py_builtins`` and
      ``_pytest.assertion.rewrite`` as ``@pytest_ar`` -- are inserted after
      the module docstring and any ``from __future__`` imports;
    * every ``ast.Assert`` found in a statement list is replaced by the
      statements returned from ``self.visit``.

    Nothing happens when the body is empty or the docstring contains
    ``PYTEST_DONT_REWRITE``.
    """
    if not mod.body:
        # Nothing to do.
        return
    # Insert some special imports at the top of the module but after any
    # docstrings and __future__ imports.
    aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
               ast.alias("_pytest.assertion.rewrite", "@pytest_ar")]
    expect_docstring = True
    pos = 0
    for item in mod.body:
        if (expect_docstring and isinstance(item, ast.Expr) and
                isinstance(item.value, ast.Str)):
            doc = item.value.s
            if "PYTEST_DONT_REWRITE" in doc:
                # The module has disabled assertion rewriting.
                return
            expect_docstring = False
        elif (not isinstance(item, ast.ImportFrom) or item.level > 0 or
              item.module != "__future__"):
            lineno = item.lineno
            break
        pos += 1
    else:
        # The body is only a docstring and/or __future__ imports: reuse the
        # last statement's line number. The former ``len(doc) - 1`` counted
        # characters rather than lines and could even go negative.
        lineno = item.lineno
    imports = [ast.Import([alias], lineno=lineno, col_offset=0)
               for alias in aliases]
    mod.body[pos:pos] = imports
    # Collect asserts.
    nodes = [mod]
    while nodes:
        node = nodes.pop()
        for name, field in ast.iter_fields(node):
            if isinstance(field, list):
                new = []
                for child in field:
                    if isinstance(child, ast.Assert):
                        # Transform assert.
                        new.extend(self.visit(child))
                    else:
                        new.append(child)
                        if isinstance(child, ast.AST):
                            nodes.append(child)
                setattr(node, name, new)
            elif (isinstance(field, ast.AST) and
                  # Don't recurse into expressions as they can't contain
                  # asserts.
                  not isinstance(field, ast.expr)):
                nodes.append(field)
def run(self, mod):
    """Find all assert statements in *mod* and rewrite them.

    *mod* is modified in place:

    * two helper imports -- the builtins module as ``@py_builtins`` and
      ``_pytest.assertion.rewrite`` as ``@pytest_ar`` -- are inserted after
      the module docstring and any ``from __future__`` imports;
    * every ``ast.Assert`` found in a statement list is replaced by the
      statements returned from ``self.visit``.

    Nothing happens when the body is empty or the docstring contains
    ``PYTEST_DONT_REWRITE``.
    """
    if not mod.body:
        # Nothing to do.
        return
    # Insert some special imports at the top of the module but after any
    # docstrings and __future__ imports.
    aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
               ast.alias("_pytest.assertion.rewrite", "@pytest_ar")]
    expect_docstring = True
    pos = 0
    for item in mod.body:
        if (expect_docstring and isinstance(item, ast.Expr) and
                isinstance(item.value, ast.Str)):
            doc = item.value.s
            if "PYTEST_DONT_REWRITE" in doc:
                # The module has disabled assertion rewriting.
                return
            expect_docstring = False
        elif (not isinstance(item, ast.ImportFrom) or item.level > 0 or
              item.module != "__future__"):
            lineno = item.lineno
            break
        pos += 1
    else:
        # The body is only a docstring and/or __future__ imports: reuse the
        # last statement's line number. The former ``len(doc) - 1`` counted
        # characters rather than lines and could even go negative.
        lineno = item.lineno
    imports = [ast.Import([alias], lineno=lineno, col_offset=0)
               for alias in aliases]
    mod.body[pos:pos] = imports
    # Collect asserts.
    nodes = [mod]
    while nodes:
        node = nodes.pop()
        for name, field in ast.iter_fields(node):
            if isinstance(field, list):
                new = []
                for child in field:
                    if isinstance(child, ast.Assert):
                        # Transform assert.
                        new.extend(self.visit(child))
                    else:
                        new.append(child)
                        if isinstance(child, ast.AST):
                            nodes.append(child)
                setattr(node, name, new)
            elif (isinstance(field, ast.AST) and
                  # Don't recurse into expressions as they can't contain
                  # asserts.
                  not isinstance(field, ast.expr)):
                nodes.append(field)
def run(self, mod):
    """Find all assert statements in *mod* and rewrite them.

    *mod* is modified in place:

    * two helper imports -- the builtins module as ``@py_builtins`` and
      ``_pytest.assertion.rewrite`` as ``@pytest_ar`` -- are inserted after
      the module docstring and any ``from __future__`` imports;
    * every ``ast.Assert`` found in a statement list is replaced by the
      statements returned from ``self.visit``.

    Nothing happens when the body is empty or the docstring contains
    ``PYTEST_DONT_REWRITE``.
    """
    if not mod.body:
        # Nothing to do.
        return
    # Insert some special imports at the top of the module but after any
    # docstrings and __future__ imports.
    aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
               ast.alias("_pytest.assertion.rewrite", "@pytest_ar")]
    expect_docstring = True
    pos = 0
    for item in mod.body:
        if (expect_docstring and isinstance(item, ast.Expr) and
                isinstance(item.value, ast.Str)):
            doc = item.value.s
            if "PYTEST_DONT_REWRITE" in doc:
                # The module has disabled assertion rewriting.
                return
            expect_docstring = False
        elif (not isinstance(item, ast.ImportFrom) or item.level > 0 or
              item.module != "__future__"):
            lineno = item.lineno
            break
        pos += 1
    else:
        # The body is only a docstring and/or __future__ imports: reuse the
        # last statement's line number. The former ``len(doc) - 1`` counted
        # characters rather than lines and could even go negative.
        lineno = item.lineno
    imports = [ast.Import([alias], lineno=lineno, col_offset=0)
               for alias in aliases]
    mod.body[pos:pos] = imports
    # Collect asserts.
    nodes = [mod]
    while nodes:
        node = nodes.pop()
        for name, field in ast.iter_fields(node):
            if isinstance(field, list):
                new = []
                for child in field:
                    if isinstance(child, ast.Assert):
                        # Transform assert.
                        new.extend(self.visit(child))
                    else:
                        new.append(child)
                        if isinstance(child, ast.AST):
                            nodes.append(child)
                setattr(node, name, new)
            elif (isinstance(field, ast.AST) and
                  # Don't recurse into expressions as they can't contain
                  # asserts.
                  not isinstance(field, ast.expr)):
                nodes.append(field)
def run(self, mod):
    """Find all assert statements in *mod* and rewrite them.

    *mod* is modified in place:

    * two helper imports -- the builtins module as ``@py_builtins`` and
      ``_pytest.assertion.rewrite`` as ``@pytest_ar`` -- are inserted after
      the module docstring and any ``from __future__`` imports;
    * every ``ast.Assert`` found in a statement list is replaced by the
      statements returned from ``self.visit``.

    Nothing happens when the body is empty or the docstring contains
    ``PYTEST_DONT_REWRITE``.
    """
    if not mod.body:
        # Nothing to do.
        return
    # Insert some special imports at the top of the module but after any
    # docstrings and __future__ imports.
    aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
               ast.alias("_pytest.assertion.rewrite", "@pytest_ar")]
    expect_docstring = True
    pos = 0
    for item in mod.body:
        if (expect_docstring and isinstance(item, ast.Expr) and
                isinstance(item.value, ast.Str)):
            doc = item.value.s
            if "PYTEST_DONT_REWRITE" in doc:
                # The module has disabled assertion rewriting.
                return
            expect_docstring = False
        elif (not isinstance(item, ast.ImportFrom) or item.level > 0 or
              item.module != "__future__"):
            lineno = item.lineno
            break
        pos += 1
    else:
        # The body is only a docstring and/or __future__ imports: reuse the
        # last statement's line number. The former ``len(doc) - 1`` counted
        # characters rather than lines and could even go negative.
        lineno = item.lineno
    imports = [ast.Import([alias], lineno=lineno, col_offset=0)
               for alias in aliases]
    mod.body[pos:pos] = imports
    # Collect asserts.
    nodes = [mod]
    while nodes:
        node = nodes.pop()
        for name, field in ast.iter_fields(node):
            if isinstance(field, list):
                new = []
                for child in field:
                    if isinstance(child, ast.Assert):
                        # Transform assert.
                        new.extend(self.visit(child))
                    else:
                        new.append(child)
                        if isinstance(child, ast.AST):
                            nodes.append(child)
                setattr(node, name, new)
            elif (isinstance(field, ast.AST) and
                  # Don't recurse into expressions as they can't contain
                  # asserts.
                  not isinstance(field, ast.expr)):
                nodes.append(field)