def get_local_vars(source, namespace):
    """Return the subset of ``namespace`` whose names are referenced in ``source``.

    ``source`` is parsed with :func:`ast.parse`; every ``ast.Name`` node found
    anywhere in the tree counts as a reference (loads and stores alike), and
    only names that are also keys of ``namespace`` are kept.

    Args:
        source: Python source code to scan.
        namespace: mapping of candidate names to values.

    Returns:
        dict mapping each referenced name to ``namespace[name]``.
    """
    # NOTE: a compiler-like analysis (e.g. distinguishing loads from stores)
    # might be wanted in the future; for now any Name occurrence counts.
    referenced = {
        node.id for node in ast.walk(ast.parse(source))
        if isinstance(node, ast.Name)
    }
    return {name: namespace[name] for name in referenced if name in namespace}
# Python walk() example source code (examples of ast.walk() usage)
def search(func, depth=1, _seen=None):
    """Return ``func``'s source preceded by the sources of the local callables it uses.

    ``func``'s source (obtained via ``get_source_code``) is parsed and two
    kinds of names are collected:

    * the name of every call whose target is a bare name (``foo(...)``), and
    * any other bare name that refers, in the inspected frame's locals, to a
      callable (e.g. a function passed as an argument).

    Every collected name present in those locals is searched recursively and
    its load string is prepended, followed by a newline.

    Args:
        func: the function whose source to collect.
        depth: frame offset passed to ``sys._getframe`` to pick the frame whose
            ``f_locals`` resolve names (1 = the caller of the outermost
            ``search``). Recursive calls pass ``depth + 1`` so they resolve
            against that same frame.
        _seen: internal; ids of functions already visited. Each function is
            emitted at most once, which also stops runaway recursion on
            (mutually) recursive functions.

    Returns:
        str: the concatenated child sources followed by ``func``'s own source.
    """
    if _seen is None:
        _seen = set()
    _seen.add(id(func))
    local_vars = sys._getframe(depth).f_locals
    source = get_source_code(func)
    tree = ast.parse(source)
    # Ordered de-duplication: a called local function appears both as the
    # Call target and as a plain Name node, so a list would collect it twice.
    child_funcs = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name):
                child_funcs[node.func.id] = None
        elif (isinstance(node, ast.Name) and node.id in local_vars
              and callable(local_vars[node.id])
              # NOTE(review): sys.builtin_module_names lists *modules* compiled
              # into the interpreter, not builtin functions -- confirm intent.
              and node.id not in sys.builtin_module_names):
            child_funcs[node.id] = None
    child_load_str = ''
    for child in child_funcs:
        if child not in local_vars or id(local_vars[child]) in _seen:
            continue
        try:
            load_string = search(local_vars[child], depth=depth + 1, _seen=_seen)
        except Exception:
            # Best effort: a child whose source cannot be collected is skipped.
            continue
        child_load_str += load_string + '\n'
    return child_load_str + source
def get_statement_startend2(lineno, node):
    """Return the 0-based ``(start, end)`` line range of the statement containing ``lineno``.

    Every statement and ``except`` handler under ``node`` contributes its
    0-based start line. The ``finally``/``else`` parts of a statement are
    treated as statements of their own, starting one line before their first
    body statement (presumably the ``finally:``/``else:`` line).

    Args:
        lineno: 0-based line number to locate.
        node: AST to search (typically the parsed module).

    Returns:
        tuple: ``(start, end)``. ``start`` is the closest statement start at or
        before ``lineno``, or ``0`` if there is none. ``end`` is the next
        statement start after ``lineno``, or ``None`` if there is none.
    """
    import ast
    from bisect import bisect_right

    # Flatten all statements and except handlers into one sorted list of
    # 0-based line numbers (AST line numbers start indexing at 1).
    linenos = []
    for x in ast.walk(node):
        if isinstance(x, (ast.stmt, ast.ExceptHandler)):
            linenos.append(x.lineno - 1)
            for name in ("finalbody", "orelse"):
                val = getattr(x, name, None)
                if val:
                    # Treat the finally/orelse part as its own statement.
                    linenos.append(val[0].lineno - 1 - 1)
    linenos.sort()
    insert_index = bisect_right(linenos, lineno)
    # Guard index 0: otherwise a lineno before the first statement would wrap
    # around to the last element, and an empty list would raise IndexError.
    start = linenos[insert_index - 1] if insert_index > 0 else 0
    end = linenos[insert_index] if insert_index < len(linenos) else None
    return start, end
def get_statement_startend2(lineno, node):
    """Return the 0-based ``(start, end)`` line range of the statement containing ``lineno``.

    Every statement and ``except`` handler under ``node`` contributes its
    0-based start line. The ``finally``/``else`` parts of a statement are
    treated as statements of their own, starting one line before their first
    body statement (presumably the ``finally:``/``else:`` line).

    Args:
        lineno: 0-based line number to locate.
        node: AST to search (typically the parsed module).

    Returns:
        tuple: ``(start, end)``. ``start`` is the closest statement start at or
        before ``lineno``, or ``0`` if there is none. ``end`` is the next
        statement start after ``lineno``, or ``None`` if there is none.
    """
    import ast
    from bisect import bisect_right

    # Flatten all statements and except handlers into one sorted list of
    # 0-based line numbers (AST line numbers start indexing at 1).
    linenos = []
    for x in ast.walk(node):
        if isinstance(x, (ast.stmt, ast.ExceptHandler)):
            linenos.append(x.lineno - 1)
            for name in ("finalbody", "orelse"):
                val = getattr(x, name, None)
                if val:
                    # Treat the finally/orelse part as its own statement.
                    linenos.append(val[0].lineno - 1 - 1)
    linenos.sort()
    insert_index = bisect_right(linenos, lineno)
    # Guard index 0: otherwise a lineno before the first statement would wrap
    # around to the last element, and an empty list would raise IndexError.
    start = linenos[insert_index - 1] if insert_index > 0 else 0
    end = linenos[insert_index] if insert_index < len(linenos) else None
    return start, end
def replaceHazards(a):
    """Rename unsupported imports and hazardous builtin calls in the tree ``a``, in place.

    Every node of the tree rooted at ``a`` is inspected:

    * ``import m`` where ``m`` is not in ``supportedLibraries`` becomes
      ``import mNotAllowed``;
    * ``from m import ...`` where ``m`` is not in ``supportedLibraries``
      becomes ``from mNotAllowed import ...``;
    * bare-name calls to ``compile``, ``eval``, ``execfile``, ``file``,
      ``open``, ``__import__`` or ``apply`` are renamed to ``<name>NotAllowed``.

    Module names that already contain ``"NotAllowed"``, or that look like
    ``r<digit>...`` (NOTE(review): presumably generated names -- confirm), are
    left untouched. Non-AST inputs are ignored.
    """
    if not isinstance(a, ast.AST):
        return

    def exempt(name):
        # Already renamed, or an ``r<digit>...`` name. The length check avoids
        # the IndexError a one-letter name such as "r" used to raise.
        return ("NotAllowed" in name
                or (len(name) > 1 and name[0] == "r" and name[1] in "0123456789"))

    # Inspect every node, not just the root: the original loop tested ``a``
    # on each iteration, so nodes below the root were never rewritten.
    for node in ast.walk(a):
        if type(node) == ast.Import:
            for alias in node.names:
                if alias.name not in supportedLibraries and not exempt(alias.name):
                    alias.name = alias.name + "NotAllowed"
        elif type(node) == ast.ImportFrom:
            # ``module`` is None for ``from . import x``; there is no module
            # name to rewrite, so it is left unchanged.
            if (node.module is not None
                    and node.module not in supportedLibraries
                    and not exempt(node.module)):
                node.module = node.module + "NotAllowed"
        elif type(node) == ast.Call:
            if type(node.func) == ast.Name and node.func.id in ["compile", "eval", "execfile", "file", "open", "__import__", "apply"]:
                node.func.id = node.func.id + "NotAllowed"
def gatherAllNames(a, keep_orig=True):
    """Gather all names in the tree (variable or otherwise).

    Names are returned along with their original names, which are used in
    variable mapping.

    Args:
        a: an AST node, a (possibly nested) list of AST nodes, or anything
            else (which yields an empty set).
        keep_orig: when True, pair each name with the node's ``originalId``
            attribute if it has one; otherwise the second item is None.

    Returns:
        set of ``(id, original_id_or_None)`` tuples for every ``ast.Name``.
    """
    if type(a) == list:
        allIds = set()
        for line in a:
            # Propagate keep_orig so list input behaves like single-node input.
            allIds |= gatherAllNames(line, keep_orig)
        return allIds
    if not isinstance(a, ast.AST):
        return set()
    allIds = set()
    for node in ast.walk(a):
        if type(node) == ast.Name:
            origName = node.originalId if (keep_orig and hasattr(node, "originalId")) else None
            allIds.add((node.id, origName))
    return allIds
def gatherAllParameters(a, keep_orig=True):
    """Gather all parameters in the tree.

    Parameters are the ``ast.arg`` nodes (function and lambda parameters,
    including ``*args``/``**kwargs``). Names are returned along with their
    original names, which are used in variable mapping.

    Args:
        a: an AST node, a (possibly nested) list of AST nodes, or anything
            else (which yields an empty set).
        keep_orig: when True, pair each name with the node's ``originalId``
            attribute if it has one; otherwise the second item is None.

    Returns:
        set of ``(arg, original_id_or_None)`` tuples.
    """
    if type(a) == list:
        allIds = set()
        for line in a:
            # Recurse into this function (the list branch previously called
            # gatherAllVariables), propagating keep_orig.
            allIds |= gatherAllParameters(line, keep_orig)
        return allIds
    if not isinstance(a, ast.AST):
        return set()
    allIds = set()
    for node in ast.walk(a):
        if type(node) == ast.arg:
            origName = node.originalId if (keep_orig and hasattr(node, "originalId")) else None
            allIds.add((node.arg, origName))
    return allIds
def getAllImports(a):
    """Gather all imported module names.

    For ``import m [as n]`` of a library in ``supportedLibraries``, the bound
    name (``n`` if given, else ``m``) is collected. For ``from m import f [as g]``,
    the bound name is collected when ``f`` is listed in ``libraryMap[m]``.
    Anything else is reported through ``log(..., "bug")``.

    Returns:
        list of bound names in ``ast.walk`` order; ``[]`` for non-AST input.
    """
    if not isinstance(a, ast.AST):
        return []
    imports = []
    for child in ast.walk(a):
        if type(child) == ast.Import:
            for alias in child.names:
                if alias.name in supportedLibraries:
                    imports.append(alias.asname if alias.asname is not None else alias.name)
                else:
                    log("astTools\tgetAllImports\tUnknown library: " + alias.name, "bug")
        elif type(child) == ast.ImportFrom:
            if child.module in supportedLibraries:
                for alias in child.names:  # these are all functions
                    if alias.name in libraryMap[child.module]:
                        imports.append(alias.asname if alias.asname is not None else alias.name)
                    else:
                        log("astTools\tgetAllImports\tUnknown import from name: " + \
                            child.module + "," + alias.name, "bug")
            else:
                # ``module`` is None for relative imports like ``from . import x``;
                # build a printable name instead of concatenating str + None.
                # For absolute imports (level 0) this is just the module name.
                moduleName = "." * (child.level or 0) + (child.module or "")
                log("astTools\tgetAllImports\tUnknown library: " + moduleName, "bug")
    return imports
def get_version(path=None):
    """Return the dotted version declared as ``VERSION_INFO`` in a Python file.

    The file is parsed (not executed). The first assignment, in ``ast.walk``
    order, of the form ``VERSION_INFO = (<number>, <number>, ...)`` is joined
    with dots, e.g. ``(1, 2, 3)`` -> ``'1.2.3'``.

    Args:
        path: file to read; defaults to ``settei/version.py`` relative to the
            current working directory.

    Returns:
        str, or None if no matching assignment is found.
    """
    if path is None:
        path = os.path.join('settei', 'version.py')
    with open(path) as f:
        tree = ast.parse(f.read(), f.name)

    def number_value(elt):
        # Python 3.8+ parses numeric literals as ast.Constant (ast.Num is
        # deprecated and removed in 3.14); older versions produce ast.Num,
        # whose value lives in ``.n``. bool is rejected, as ast.Num did.
        if isinstance(elt, getattr(ast, 'Constant', ())):
            value = elt.value
        else:
            value = getattr(elt, 'n', None)
        if isinstance(value, bool) or not isinstance(value, (int, float, complex)):
            return None
        return value

    for node in ast.walk(tree):
        if not (isinstance(node, ast.Assign) and len(node.targets) == 1):
            continue
        target, = node.targets
        value = node.value
        if not (isinstance(target, ast.Name) and
                target.id == 'VERSION_INFO' and
                isinstance(value, ast.Tuple)):
            continue
        numbers = [number_value(elt) for elt in value.elts]
        if any(number is None for number in numbers):
            continue
        return '.'.join(str(number) for number in numbers)
    return None
def process(fl, external, genfiles, vendor):
    """Evaluate the rule calls found in the WORKSPACE-style file ``fl``.

    ``fl`` is parsed, not executed. For every call ``name(...)`` where ``name``
    is a callable attribute of ``WORKSPACE(external, genfiles, vendor)``, the
    ``(path, name)`` pair from ``keywords(stmt)`` is normalised and passed to
    that method as ``method(name, path)``. Normalising strips a trailing
    ``.git`` and then applies ``pathmap``.

    Returns:
        list of the methods' return values, in ``ast.walk`` order.
    """
    # Close the file promptly instead of leaking the handle.
    with open(fl) as f:
        src = f.read()
    tree = ast.parse(src, fl)
    lst = []
    wksp = WORKSPACE(external, genfiles, vendor)
    for stmt in ast.walk(tree):
        # Only bare-name calls can map to workspace methods; attribute calls
        # such as ``obj.method()`` have no ``func.id`` and used to raise.
        if not (isinstance(stmt, ast.Call) and isinstance(stmt.func, ast.Name)):
            continue
        fn = getattr(wksp, stmt.func.id, "")
        if not callable(fn):
            continue
        path, name = keywords(stmt)
        if path.endswith(".git"):
            path = path[:-4]
        path = pathmap.get(path, path)
        lst.append(fn(name, path))
    return lst
def walk_python_files():
    u'''
    Generator that yields all CKAN Python source files.

    Yields 2-tuples containing the filename in absolute and relative (to
    the project root) form. Hidden directories (names starting with ".")
    and directories listed in IGNORED_DIRS are not descended into.
    '''
    def _is_dir_ignored(root, d):
        # ``root`` is the parent directory relative to PROJECT_ROOT. Use the
        # parameter rather than the enclosing loop's ``rel_root`` so the
        # helper does not silently depend on closure state.
        if d.startswith(u'.'):
            return True
        return os.path.join(root, d) in IGNORED_DIRS

    for abs_root, dirnames, filenames in os.walk(PROJECT_ROOT):
        rel_root = os.path.relpath(abs_root, PROJECT_ROOT)
        if rel_root == u'.':
            rel_root = u''
        # Prune in place so os.walk skips ignored directories entirely.
        dirnames[:] = [d for d in dirnames if not _is_dir_ignored(rel_root, d)]
        for filename in filenames:
            if not filename.endswith(u'.py'):
                continue
            abs_name = os.path.join(abs_root, filename)
            rel_name = os.path.join(rel_root, filename)
            yield abs_name, rel_name
def run(self):
    """Yield ``(line, offset, msg, rtype)`` results for assignments and function definitions.

    When ``self.filename`` is ``'stdin'`` the tree is re-parsed from
    ``stdin_utils.stdin_get_value()``; otherwise ``self.tree`` is used.
    Every child node is first given a ``__flake8_builtins_parent`` attribute
    pointing at its parent. Then each ``ast.Assign`` goes through
    ``self.check_assignment`` and each ``ast.FunctionDef`` goes through
    ``self.check_function_definition``; whatever those return is re-yielded.
    """
    tree = self.tree
    if self.filename == 'stdin':
        tree = ast.parse(stdin_utils.stdin_get_value())

    for parent in ast.walk(tree):
        for node in ast.iter_child_nodes(parent):
            node.__flake8_builtins_parent = parent

    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            found = self.check_assignment(node)
        elif isinstance(node, ast.FunctionDef):
            found = self.check_function_definition(node)
        else:
            continue
        # A falsy result (None or empty) contributes nothing.
        for line, offset, msg, rtype in found or ():
            yield line, offset, msg, rtype
def walk(node):
    """
    Yield ``node`` and every descendant in depth-first pre-order (each parent
    before its children, children in their natural order).

    This mirrors ``ast.walk()`` but with a different traversal order. It works
    for both ``ast`` and ``astroid`` trees and, like ``iter_children()``, skips
    the singleton nodes that ``ast`` generates.
    """
    children_of = iter_children_func(node)
    visited = set()
    pending = [node]
    while pending:
        current = pending.pop()
        # Protect against an infinite loop if a bad tree reaches a node twice.
        assert current not in visited
        visited.add(current)
        yield current
        # Inserting every child at the same index leaves them on the stack in
        # reverse order, so the first child is popped next. This is faster
        # than building a list and reversing it.
        base = len(pending)
        for child in children_of(current):
            pending.insert(base, child)
def print_timing(self):
    # pylint: disable=no-self-use
    # Benchmarks asttokens.util.walk (which uses the same approach as
    # visit_tree()) against ast.walk. It does not run as a normal unittest;
    # to see the timings, e.g. after experimenting with the implementation, run:
    #
    #   nosetests -i print_timing -s tests.test_util
    #
    import textwrap
    import timeit
    setup = textwrap.dedent(
        '''
import ast, asttokens
source = "foo(bar(1 + 2), 'hello' + ', ' + 'world')"
atok = asttokens.ASTTokens(source, parse=True)
''')
    benchmarks = (
        ("ast", 'len(list(ast.walk(atok.tree)))'),
        ("util", 'len(list(asttokens.util.walk(atok.tree)))'),
    )
    for label, statement in benchmarks:
        timings = timeit.repeat(setup=setup, number=10000, stmt=statement)
        print(label, sorted(timings))
def test_walk_ast(self):
    """asttokens.util.walk visits an ``ast`` tree in depth-first pre-order."""
    import sys
    atok = asttokens.ASTTokens(self.source, parse=True)

    def view(node):
        return "%s:%s" % (node.__class__.__name__, atok.get_text(node))

    # Python 3.8+ parses every literal as ast.Constant; older versions produce
    # ast.Num / ast.Str, so the expected class names depend on the version.
    if sys.version_info >= (3, 8):
        num = string = 'Constant'
    else:
        num, string = 'Num', 'Str'

    scan = [view(n) for n in asttokens.util.walk(atok.tree)]
    self.assertEqual(scan, [
        "Module:foo(bar(1 + 2), 'hello' + ', ' + 'world')",
        "Expr:foo(bar(1 + 2), 'hello' + ', ' + 'world')",
        "Call:foo(bar(1 + 2), 'hello' + ', ' + 'world')",
        'Name:foo',
        'Call:bar(1 + 2)',
        'Name:bar',
        'BinOp:1 + 2',
        num + ':1',
        num + ':2',
        "BinOp:'hello' + ', ' + 'world'",
        "BinOp:'hello' + ', '",
        string + ":'hello'",
        string + ":', '",
        string + ":'world'",
    ])
def test_walk_astroid(self):
    """asttokens.util.walk visits an ``astroid`` tree in depth-first pre-order."""
    tree = astroid.builder.parse(self.source)
    atok = asttokens.ASTTokens(self.source, tree=tree)
    expected = [
        "Module:foo(bar(1 + 2), 'hello' + ', ' + 'world')",
        "Expr:foo(bar(1 + 2), 'hello' + ', ' + 'world')",
        "Call:foo(bar(1 + 2), 'hello' + ', ' + 'world')",
        'Name:foo',
        'Call:bar(1 + 2)',
        'Name:bar',
        'BinOp:1 + 2',
        'Const:1',
        'Const:2',
        "BinOp:'hello' + ', ' + 'world'",
        "BinOp:'hello' + ', '",
        "Const:'hello'",
        "Const:', '",
        "Const:'world'",
    ]
    visited = [
        "%s:%s" % (n.__class__.__name__, atok.get_text(n))
        for n in asttokens.util.walk(atok.tree)
    ]
    self.assertEqual(visited, expected)
def test_replace(self):
    """asttokens.util.replace applies (start, end, text) edits regardless of their order."""
    import sys

    def is_str_literal(node):
        # ast.Str is deprecated since Python 3.8 (and removed in 3.14); there,
        # string literals are parsed as ast.Constant holding a str.
        if sys.version_info >= (3, 8):
            return isinstance(node, ast.Constant) and isinstance(node.value, str)
        return isinstance(node, ast.Str)

    self.assertEqual(asttokens.util.replace("this is a test", [(0, 4, "X"), (8, 9, "THE")]),
                     "X is THE test")
    self.assertEqual(asttokens.util.replace("this is a test", []), "this is a test")
    self.assertEqual(asttokens.util.replace("this is a test", [(7, 7, " NOT")]), "this is NOT a test")

    source = "foo(bar(1 + 2), 'hello' + ', ' + 'world')"
    atok = asttokens.ASTTokens(source, parse=True)
    names = [n for n in asttokens.util.walk(atok.tree) if isinstance(n, ast.Name)]
    strings = [n for n in asttokens.util.walk(atok.tree) if is_str_literal(n)]
    repl1 = [atok.get_text_range(n) + ('TEST',) for n in names]
    repl2 = [atok.get_text_range(n) + ('val',) for n in strings]
    self.assertEqual(asttokens.util.replace(source, repl1 + repl2),
                     "TEST(TEST(1 + 2), val + val + val)")
    self.assertEqual(asttokens.util.replace(source, repl2 + repl1),
                     "TEST(TEST(1 + 2), val + val + val)")
def _fine_property_definition(self, property_name):
"""Find the lines in the source code that contain this property's name and definition.
This function can find both attribute assignments as well as methods/functions.
Args:
property_name (str): the name of the property to look up in the template definition
Returns:
tuple: line numbers for the start and end of the attribute definition
"""
for node in ast.walk(ast.parse(self._source)):
if isinstance(node, ast.Assign) and node.targets[0].id == property_name:
return node.targets[0].lineno - 1, self._get_node_line_end(node)
elif isinstance(node, ast.FunctionDef) and node.name == property_name:
return node.lineno - 1, self._get_node_line_end(node)
raise ValueError('The requested node could not be found.')
def _find_non_builtin_globals(source, codeobj):
try:
import ast
except ImportError:
return None
try:
import __builtin__
except ImportError:
import builtins as __builtin__
vars = dict.fromkeys(codeobj.co_varnames)
return [
node.id for node in ast.walk(ast.parse(source))
if isinstance(node, ast.Name) and
node.id not in vars and
node.id not in __builtin__.__dict__
]
def get_statement_startend2(lineno, node):
    """Return the 0-based ``(start, end)`` line range of the statement containing ``lineno``.

    Every statement and ``except`` handler under ``node`` contributes its
    0-based start line. The ``finally``/``else`` parts of a statement are
    treated as statements of their own, starting one line before their first
    body statement (presumably the ``finally:``/``else:`` line).

    Args:
        lineno: 0-based line number to locate.
        node: AST to search (typically the parsed module).

    Returns:
        tuple: ``(start, end)``. ``start`` is the closest statement start at or
        before ``lineno``, or ``0`` if there is none. ``end`` is the next
        statement start after ``lineno``, or ``None`` if there is none.
    """
    import ast
    from bisect import bisect_right

    # Flatten all statements and except handlers into one sorted list of
    # 0-based line numbers (AST line numbers start indexing at 1).
    linenos = []
    for x in ast.walk(node):
        if isinstance(x, (ast.stmt, ast.ExceptHandler)):
            linenos.append(x.lineno - 1)
            for name in ("finalbody", "orelse"):
                val = getattr(x, name, None)
                if val:
                    # Treat the finally/orelse part as its own statement.
                    linenos.append(val[0].lineno - 1 - 1)
    linenos.sort()
    insert_index = bisect_right(linenos, lineno)
    # Guard index 0: otherwise a lineno before the first statement would wrap
    # around to the last element, and an empty list would raise IndexError.
    start = linenos[insert_index - 1] if insert_index > 0 else 0
    end = linenos[insert_index] if insert_index < len(linenos) else None
    return start, end
def get_statement_startend2(lineno, node):
    """Return the 0-based ``(start, end)`` line range of the statement containing ``lineno``.

    Every statement and ``except`` handler under ``node`` contributes its
    0-based start line. The ``finally``/``else`` parts of a statement are
    treated as statements of their own, starting one line before their first
    body statement (presumably the ``finally:``/``else:`` line).

    Args:
        lineno: 0-based line number to locate.
        node: AST to search (typically the parsed module).

    Returns:
        tuple: ``(start, end)``. ``start`` is the closest statement start at or
        before ``lineno``, or ``0`` if there is none. ``end`` is the next
        statement start after ``lineno``, or ``None`` if there is none.
    """
    import ast
    from bisect import bisect_right

    # Flatten all statements and except handlers into one sorted list of
    # 0-based line numbers (AST line numbers start indexing at 1).
    linenos = []
    for x in ast.walk(node):
        if isinstance(x, (ast.stmt, ast.ExceptHandler)):
            linenos.append(x.lineno - 1)
            for name in ("finalbody", "orelse"):
                val = getattr(x, name, None)
                if val:
                    # Treat the finally/orelse part as its own statement.
                    linenos.append(val[0].lineno - 1 - 1)
    linenos.sort()
    insert_index = bisect_right(linenos, lineno)
    # Guard index 0: otherwise a lineno before the first statement would wrap
    # around to the last element, and an empty list would raise IndexError.
    start = linenos[insert_index - 1] if insert_index > 0 else 0
    end = linenos[insert_index] if insert_index < len(linenos) else None
    return start, end
def linerange(node):
    """Get line number range from a node.

    The node's own compound-statement bodies (``body``, ``orelse``,
    ``handlers``, ``finalbody``) are temporarily replaced with empty lists so
    only the node's header is measured. They are restored afterwards, even if
    walking the tree raises.

    Returns:
        list of int: every line number from the smallest to the largest
        ``lineno`` found under ``node`` (inclusive), or ``[0, 1]`` if no node
        carries a line number.
    """
    stripped = {}
    for key in ("body", "orelse", "handlers", "finalbody"):
        if hasattr(node, key):
            stripped[key] = getattr(node, key)
            setattr(node, key, [])
    linenos = []
    try:
        linenos = [n.lineno for n in ast.walk(node) if hasattr(n, 'lineno')]
    finally:
        # Restore every detached field, including ones whose value was None.
        for key, value in stripped.items():
            setattr(node, key, value)
    if linenos:
        return list(range(min(linenos), max(linenos) + 1))
    return [0, 1]
def get_statement_startend2(lineno, node):
    """Return the 0-based ``(start, end)`` line range of the statement containing ``lineno``.

    Every statement and ``except`` handler under ``node`` contributes its
    0-based start line. The ``finally``/``else`` parts of a statement are
    treated as statements of their own, starting one line before their first
    body statement (presumably the ``finally:``/``else:`` line).

    Args:
        lineno: 0-based line number to locate.
        node: AST to search (typically the parsed module).

    Returns:
        tuple: ``(start, end)``. ``start`` is the closest statement start at or
        before ``lineno``, or ``0`` if there is none. ``end`` is the next
        statement start after ``lineno``, or ``None`` if there is none.
    """
    import ast
    from bisect import bisect_right

    # Flatten all statements and except handlers into one sorted list of
    # 0-based line numbers (AST line numbers start indexing at 1).
    linenos = []
    for x in ast.walk(node):
        if isinstance(x, (ast.stmt, ast.ExceptHandler)):
            linenos.append(x.lineno - 1)
            for name in ("finalbody", "orelse"):
                val = getattr(x, name, None)
                if val:
                    # Treat the finally/orelse part as its own statement.
                    linenos.append(val[0].lineno - 1 - 1)
    linenos.sort()
    insert_index = bisect_right(linenos, lineno)
    # Guard index 0: otherwise a lineno before the first statement would wrap
    # around to the last element, and an empty list would raise IndexError.
    start = linenos[insert_index - 1] if insert_index > 0 else 0
    end = linenos[insert_index] if insert_index < len(linenos) else None
    return start, end
def get_statement_startend2(lineno, node):
    """Return the 0-based ``(start, end)`` line range of the statement containing ``lineno``.

    Every statement and ``except`` handler under ``node`` contributes its
    0-based start line. The ``finally``/``else`` parts of a statement are
    treated as statements of their own, starting one line before their first
    body statement (presumably the ``finally:``/``else:`` line).

    Args:
        lineno: 0-based line number to locate.
        node: AST to search (typically the parsed module).

    Returns:
        tuple: ``(start, end)``. ``start`` is the closest statement start at or
        before ``lineno``, or ``0`` if there is none. ``end`` is the next
        statement start after ``lineno``, or ``None`` if there is none.
    """
    import ast
    from bisect import bisect_right

    # Flatten all statements and except handlers into one sorted list of
    # 0-based line numbers (AST line numbers start indexing at 1).
    linenos = []
    for x in ast.walk(node):
        if isinstance(x, (ast.stmt, ast.ExceptHandler)):
            linenos.append(x.lineno - 1)
            for name in ("finalbody", "orelse"):
                val = getattr(x, name, None)
                if val:
                    # Treat the finally/orelse part as its own statement.
                    linenos.append(val[0].lineno - 1 - 1)
    linenos.sort()
    insert_index = bisect_right(linenos, lineno)
    # Guard index 0: otherwise a lineno before the first statement would wrap
    # around to the last element, and an empty list would raise IndexError.
    start = linenos[insert_index - 1] if insert_index > 0 else 0
    end = linenos[insert_index] if insert_index < len(linenos) else None
    return start, end
def get_statement_startend2(lineno, node):
    """Return the 0-based ``(start, end)`` line range of the statement containing ``lineno``.

    Every statement and ``except`` handler under ``node`` contributes its
    0-based start line. The ``finally``/``else`` parts of a statement are
    treated as statements of their own, starting one line before their first
    body statement (presumably the ``finally:``/``else:`` line).

    Args:
        lineno: 0-based line number to locate.
        node: AST to search (typically the parsed module).

    Returns:
        tuple: ``(start, end)``. ``start`` is the closest statement start at or
        before ``lineno``, or ``0`` if there is none. ``end`` is the next
        statement start after ``lineno``, or ``None`` if there is none.
    """
    import ast
    from bisect import bisect_right

    # Flatten all statements and except handlers into one sorted list of
    # 0-based line numbers (AST line numbers start indexing at 1).
    linenos = []
    for x in ast.walk(node):
        if isinstance(x, (ast.stmt, ast.ExceptHandler)):
            linenos.append(x.lineno - 1)
            for name in ("finalbody", "orelse"):
                val = getattr(x, name, None)
                if val:
                    # Treat the finally/orelse part as its own statement.
                    linenos.append(val[0].lineno - 1 - 1)
    linenos.sort()
    insert_index = bisect_right(linenos, lineno)
    # Guard index 0: otherwise a lineno before the first statement would wrap
    # around to the last element, and an empty list would raise IndexError.
    start = linenos[insert_index - 1] if insert_index > 0 else 0
    end = linenos[insert_index] if insert_index < len(linenos) else None
    return start, end
def get_statement_startend2(lineno, node):
    """Return the 0-based ``(start, end)`` line range of the statement containing ``lineno``.

    Every statement and ``except`` handler under ``node`` contributes its
    0-based start line. The ``finally``/``else`` parts of a statement are
    treated as statements of their own, starting one line before their first
    body statement (presumably the ``finally:``/``else:`` line).

    Args:
        lineno: 0-based line number to locate.
        node: AST to search (typically the parsed module).

    Returns:
        tuple: ``(start, end)``. ``start`` is the closest statement start at or
        before ``lineno``, or ``0`` if there is none. ``end`` is the next
        statement start after ``lineno``, or ``None`` if there is none.
    """
    import ast
    from bisect import bisect_right

    # Flatten all statements and except handlers into one sorted list of
    # 0-based line numbers (AST line numbers start indexing at 1).
    linenos = []
    for x in ast.walk(node):
        if isinstance(x, (ast.stmt, ast.ExceptHandler)):
            linenos.append(x.lineno - 1)
            for name in ("finalbody", "orelse"):
                val = getattr(x, name, None)
                if val:
                    # Treat the finally/orelse part as its own statement.
                    linenos.append(val[0].lineno - 1 - 1)
    linenos.sort()
    insert_index = bisect_right(linenos, lineno)
    # Guard index 0: otherwise a lineno before the first statement would wrap
    # around to the last element, and an empty list would raise IndexError.
    start = linenos[insert_index - 1] if insert_index > 0 else 0
    end = linenos[insert_index] if insert_index < len(linenos) else None
    return start, end
def get_orig_line_from_s_orig(s_orig, line_no):
    """Map a line number of the parsed program back to its original line number.

    ``s_orig`` is parsed with ``py_ast.get_ast``. The first node, in
    ``ast.walk`` order, that has both ``lineno`` and ``orig_lineno`` and whose
    ``lineno`` equals ``line_no`` supplies the result.

    Returns:
        ``-1`` if ``line_no`` is None; otherwise the matching node's
        ``orig_lineno``, or ``line_no`` unchanged when no node matches.
    """
    if line_no is None:
        return -1
    node = py_ast.get_ast(s_orig)
    # next() stops at the first match instead of materialising every match.
    match = next(
        (n for n in ast.walk(node)
         if hasattr(n, 'lineno') and hasattr(n, 'orig_lineno') and n.lineno == line_no),
        None,
    )
    if match is None:
        # No node maps this line; fall back to the caller's line number.
        return line_no
    return match.orig_lineno
def get_ast(source_prog):
    """Return the AST of the program, with comments converted into string literals.

    Args:
        source_prog: string version of the source code.

    After parsing, ``add_parent_info`` is applied to the tree. Each statement
    node that has a ``lineno`` gets ``orig_lineno`` set to ``getLineNum(stmt)``,
    unless that returns -1. Finally every ``parent`` attribute is deleted
    again. NOTE(review): presumably the parent links are only needed by
    ``getLineNum`` -- verify.
    """
    tree = ast.parse(comment_to_str(source_prog, TRANS_PREFIXES))
    add_parent_info(tree)

    statements = [n for n in ast.walk(tree) if isinstance(n, ast.stmt)]
    for stmt in statements:
        if not hasattr(stmt, 'lineno'):
            continue
        orig = getLineNum(stmt)
        if orig != -1:
            stmt.orig_lineno = orig

    for n in list(ast.walk(tree)):
        if hasattr(n, 'parent'):
            delattr(n, 'parent')
    return tree
def main():
    """Run ``check`` on every ``.py`` file under the directory given on the command line.

    Each problem reported by ``check`` is printed as ``path:lineno : msg``,
    followed by the offending source line. If any problem was found,
    ``sys.exit`` is called with a summary message.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('dir')
    args = parser.parse_args()
    n_err = 0
    # Loop names avoid shadowing the builtins ``dir`` and ``file``.
    for dirpath, _, filenames in os.walk(args.dir):
        for filename in filenames:
            if os.path.splitext(filename)[1] != '.py':
                continue
            path = os.path.join(dirpath, filename)
            # Close the file promptly instead of leaking the handle.
            with open(path) as f:
                lines = f.readlines()
            for lineno, msg in check(''.join(lines)):
                print('{:s}:{:d} : {:s}'.format(path, lineno, msg))
                print(lines[lineno - 1])
                n_err += 1
    if n_err > 0:
        sys.exit('{:d} style errors are found.'.format(n_err))
def check_nesting(self, **kwargs):
    """Inspect the code for too much nested expressions.

    Does nothing unless a ``max_nesting`` keyword argument is given, or when
    ``self.parsed_code`` has no statements.

    The "nesting level" is the number of nodes with a ``body`` attribute that
    ``ast.walk`` finds under the first top-level statement of
    ``self.parsed_code``. NOTE(review): this counts body-bearing nodes, so
    sibling blocks add up too; it is not strictly the nesting depth.

    When the level exceeds ``max_nesting``, a ``nesting_too_deep`` issue is
    added to ``self.issues`` one line after the last such node (in walk order).
    """
    try:
        max_nesting = kwargs['max_nesting']
    except KeyError:
        return
    body = self.parsed_code.body
    if not body:
        # Empty module: nothing to inspect (``body[0]`` would raise IndexError).
        return
    # Traverse the nodes and find those that are nested
    # (have 'body' attribute).
    nodes = [(node, node.lineno) for node
             in ast.walk(body[0])
             if hasattr(node, 'body')]
    nesting_level = len(nodes)
    if nesting_level > max_nesting:
        # The line number where the error was found
        # is the next one (thus + 1):
        line_number = nodes[-1][1] + 1
        self.issues[line_number].add(
            self.code_errors.nesting_too_deep(
                nesting_level, max_nesting
            )
        )