def generate_pyast(code):
    """
    Parse Python source and return its syntax tree as plain dicts and lists.

    Every ``ast.AST`` node becomes a dict whose keys are the node's field names
    passed through ``to_camelcase``; the dict also gets a ``'node_type'`` entry
    holding the camel-cased node class name. Lists are converted element by
    element, and any other value (str, int, None, ...) is kept as is.

    :param code: the code as a string
    :return: a structure of lists and dicts only, representing the ast of code
    """
    import ast

    def convert(value):
        if isinstance(value, list):
            return [convert(item) for item in value]
        if not isinstance(value, ast.AST):
            return value
        converted = {}
        # noinspection PyProtectedMember
        for field_name in value._fields:
            key = to_camelcase(field_name)
            converted[key] = convert(getattr(value, field_name))
        converted['node_type'] = to_camelcase(value.__class__.__name__)
        return converted

    tree = ast.parse(code)
    return convert(tree)
# Example source code using Python's ast.AST class ("python类AST的实例源码")
def ast_to_code(node: ast.AST, old_code: types.CodeType = None, file: str = None) -> types.CodeType:
    """
    Compile node object to code.

    When *node* is given it is compiled in ``'exec'`` mode with
    ``dont_inherit=True`` and the new code object is returned; otherwise
    *old_code* is returned unchanged.

    :param node: the AST to compile, or a falsy value to fall back to *old_code*
    :param old_code: an existing code object; when *file* is not given its
        source file (``inspect.getfile``) is used as the filename
    :param file: filename recorded in the compiled code; defaults to the file
        of *old_code*, then to ``'<file>'``
    :return: the freshly compiled code object, or *old_code*
    :raises TypeError: if *node* is not an AST or *old_code* is not a code object
    :raises ValueError: if neither *node* nor *old_code* is given
    """
    if node and not isinstance(node, ast.AST):
        raise TypeError('Unexpected type for node: {}'.format(str(type(node))))
    if old_code and not isinstance(old_code, types.CodeType):
        raise TypeError('Unexpected type for old_code: {}'.format(str(type(old_code))))
    if not node:
        if not old_code:
            raise ValueError('Not specified value')
        return old_code
    file = file or (inspect.getfile(old_code) if old_code else None)
    # NOTE(review): _call_with_frames_removed is presumably importlib's helper
    # that hides this frame from tracebacks -- confirm where it is imported.
    return _call_with_frames_removed(
        compile,
        source=node,
        filename=file or '<file>',
        mode='exec',
        dont_inherit=True,
    )
def ast_to_class(node: ast.AST, old_class: type = None, file: str = None) -> type:
    """
    Build a class from an AST (not implemented yet).

    The arguments are validated and then NotImplementedError is raised.

    :param node: an ``ast.Module`` or ``ast.ClassDef``, or a falsy value
    :param old_class: an existing class -- presumably the one being replaced
        (TODO confirm once implemented)
    :param file: filename for the compiled class; currently unused
    :return: never returns normally
    :raises TypeError: if *node* is not a Module/ClassDef or *old_class* is not a class
    :raises NotImplementedError: always, once the arguments pass validation
    """
    if node and not isinstance(node, (ast.Module, ast.ClassDef)):
        raise TypeError('Unexpected type for node: {}'.format(str(type(node))))
    if old_class and not isinstance(old_class, type):
        raise TypeError('Unexpected type for old_class: {}'.format(str(type(old_class))))
    # TODO: implement class construction from the AST.
    raise NotImplementedError
def generatePathToId(a, id, globalId=None):
    """Return the path from *a* down to the node whose ``global_id`` equals *id*.

    If *globalId* is given, the matching node must also have a
    ``variableGlobalId`` equal to it. The path is built bottom-up: each
    ancestor appends the list index (for list fields) and then a
    ``(field_name, astNames[type(ancestor)])`` tuple. The result therefore
    starts at the target's parent and ends at *a*. ``astNames`` is defined
    elsewhere; presumably it maps node classes to display names.

    :return: the path list ([] when *a* itself matches), or None if not found
    """
    if not isinstance(a, ast.AST):
        return None
    if hasattr(a, "global_id") and a.global_id == id:
        if globalId is None or (hasattr(a, "variableGlobalId") and a.variableGlobalId == globalId):
            return []
    for field in a._fields:
        value = getattr(a, field)
        if type(value) == list:
            candidates = list(enumerate(value))
        else:
            candidates = [(None, value)]
        for index, child in candidates:
            path = generatePathToId(child, id, globalId)
            if path is None:
                continue
            if index is not None:
                path.append(index)
            path.append((field, astNames[type(a)]))
            return path
    return None
def createNameMap(a, d=None):
    """Map bound names in *a* to the ``originalId`` stored on their nodes.

    Function and class names, ``ast.arg`` parameters and ``ast.Name`` nodes that
    carry an ``originalId`` attribute contribute an entry. An existing key is
    never overwritten, so the first binding visited wins. Module bodies are
    walked last-to-first, which makes the last top-level definition win.

    :param a: the AST to scan (non-AST values contribute nothing)
    :param d: optional dict to fill in place; a new dict is used when None
    :return: the name -> originalId mapping
    """
    names = {} if d is None else d
    if not isinstance(a, ast.AST):
        return names
    kind = type(a)
    if kind == ast.Module:
        for statement in reversed(a.body):
            createNameMap(statement, names)
        return names
    if kind == ast.arg or kind == ast.Name:
        bound = a.arg if kind == ast.arg else a.id
        if hasattr(a, "originalId") and bound not in names:
            names[bound] = a.originalId
        return names
    if kind in (ast.FunctionDef, ast.ClassDef):
        if hasattr(a, "originalId") and a.name not in names:
            names[a.name] = a.originalId
    for child in ast.iter_child_nodes(a):
        createNameMap(child, names)
    return names
def findId(a, id):
    """Return the first node (depth-first, pre-order) whose ``global_id`` equals *id*.

    *a* may be an AST node or a list of nodes; lists are searched element by
    element. Returns None when no match exists.
    """
    if hasattr(a, "global_id") and a.global_id == id:
        return a
    if type(a) == list:
        children = a
    elif isinstance(a, ast.AST):
        children = ast.iter_child_nodes(a)
    else:
        return None
    for child in children:
        found = findId(child, id)
        if found is not None:
            return found
    return None
def getSubtreeContext(super, sub):
    """Locate *sub* inside *super* and return where it lives.

    Children are compared with ``compareASTs(..., checkEquality=True) == 0``
    in field order, depth-first.

    :return: ``(the_list, index, node)`` when the match is an element of a list
        field, ``(parent, field_name, node)`` when it is a single-node field,
        or None when *super* is not an AST or no match is found
    """
    if not isinstance(super, ast.AST):
        return None
    for field in super._fields:
        child = getattr(super, field)
        if type(child) != list:
            if compareASTs(child, sub, checkEquality=True) == 0:
                return (super, field, child)
            found = getSubtreeContext(child, sub)
            if found is not None:
                return found
            continue
        for index, item in enumerate(child):
            if compareASTs(item, sub, checkEquality=True) == 0:
                return (child, index, item)
            found = getSubtreeContext(item, sub)
            if found is not None:
                return found
    return None
def basicTypeSpecialFunction(cv):
    """If you're in a number or string (which has no metadata), move up to the AST to make the special functions work.

    When the first step of ``cv.path`` targets a primitive field (number,
    string, Name id, argument name, name constant, bytes, or alias name),
    the vector is rewritten to replace the node that owns that field:

    * ``cv.oldSubtree`` becomes a deep copy of ``traverseTree(cv.start)``,
      evaluated on a copy of ``cv`` (presumably the owning node -- confirm);
    * ``cv.newSubtree`` (the new primitive value) is wrapped in a new node of
      the matching type, reusing ``ctx``/``annotation``/``asname`` from the
      old node where needed;
    * the first path step is dropped.

    SwapVector and MoveVector instances, and vectors whose first step is not
    listed, are returned unchanged.

    NOTE(review): ast.Num, ast.Str, ast.Bytes and ast.NameConstant are
    deprecated aliases since Python 3.8 (they build ast.Constant) and are
    removed in later releases -- confirm the target interpreter.

    :param cv: the change vector; mutated in place when rewritten
    :return: cv
    """
    if isinstance(cv, SwapVector) or isinstance(cv, MoveVector):
        return cv
    if (cv.path[0] in [('n', 'Number'), ('s', 'String'), ('id', 'Name'), ('arg', 'Argument'),
                       ('value', 'Name Constant'), ('s', 'Bytes'), ('name', 'Alias')]):
        # Traverse on a copy so that looking up the old node cannot disturb cv.
        cvCopy = cv.deepcopy()
        cv.oldSubtree = deepcopy(cvCopy.traverseTree(cv.start))
        # Wrap the new primitive value in a node of the same kind as the old one.
        if cv.path[0] == ('n', 'Number'):
            cv.newSubtree = ast.Num(cv.newSubtree)
        elif cv.path[0] == ('s', 'String'):
            cv.newSubtree = ast.Str(cv.newSubtree)
        elif cv.path[0] == ('id', 'Name'):
            cv.newSubtree = ast.Name(cv.newSubtree, cv.oldSubtree.ctx)
        elif cv.path[0] == ('arg', 'Argument'):
            cv.newSubtree = ast.arg(cv.newSubtree, cv.oldSubtree.annotation)
        elif cv.path[0] == ('value', 'Name Constant'):
            cv.newSubtree = ast.NameConstant(cv.newSubtree)
        elif cv.path[0] == ('s', 'Bytes'):
            cv.newSubtree = ast.Bytes(cv.newSubtree)
        elif cv.path[0] == ('name', 'Alias'):
            cv.newSubtree = ast.alias(cv.newSubtree, cv.oldSubtree.asname)
        # The vector now addresses the owning node, so drop the primitive step.
        cv.path = cv.path[1:]
    return cv
def _is_number_literal_for_range(node, target):
    """Return True if *node* is a numeric (non-bool) literal equal to *target*.

    Accepts ``ast.Constant`` (what the parser produces on Python 3.8+) as well
    as the legacy ``Num`` node produced by older parsers.
    """
    if isinstance(node, getattr(ast, "Constant", ())):
        value = node.value
    elif type(node).__name__ == "Num":
        value = node.n
    else:
        return False
    return (isinstance(value, (int, float, complex))
            and not isinstance(value, bool)
            and value == target)


def cleanupRanges(a):
    """Remove any range shenanigans, because Python lets you include unnecessary values

    Rewrites ``range(a, b, 1)`` to ``range(a, b)`` and ``range(0, b)`` to
    ``range(b)`` in place. The node is then handed to
    ``applyToChildren(a, cleanupRanges)``, presumably to process its children.
    Calls that contain starred arguments are left alone, because dropping a
    literal there would shift which positional slot the unpacked values fill.

    :param a: an AST node (non-AST values are returned unchanged)
    :return: the result of ``applyToChildren`` for AST input
    """
    if not isinstance(a, ast.AST):
        return a
    if (type(a) == ast.Call and type(a.func) == ast.Name and a.func.id == "range"
            and not any(isinstance(arg, ast.Starred) for arg in a.args)):
        # The step defaults to 1!
        if len(a.args) == 3 and _is_number_literal_for_range(a.args[2], 1):
            a.args = a.args[:-1]
        # The start defaults to 0!
        if len(a.args) == 2 and _is_number_literal_for_range(a.args[0], 0):
            a.args = a.args[1:]
    return applyToChildren(a, cleanupRanges)
def _is_number_literal_for_slice(node, target):
    """Return True if *node* is a numeric (non-bool) literal equal to *target*.

    Accepts ``ast.Constant`` (what the parser produces on Python 3.8+) as well
    as the legacy ``Num`` node produced by older parsers.
    """
    if isinstance(node, getattr(ast, "Constant", ())):
        value = node.value
    elif type(node).__name__ == "Num":
        value = node.n
    else:
        return False
    return (isinstance(value, (int, float, complex))
            and not isinstance(value, bool)
            and value == target)


def cleanupSlices(a):
    """Remove any slice shenanigans, because Python lets you include unnecessary values

    In place, on ``x[lower:upper:step]`` subscripts:

    * a literal ``0`` lower bound is dropped;
    * an upper bound of ``len(x)`` (same expression as the subscripted value,
      according to ``compareASTs``) is dropped;
    * a literal ``1`` step is dropped.

    The node is then handed to ``applyToChildren(a, cleanupSlices)``.

    :param a: an AST node (non-AST values are returned unchanged)
    :return: the result of ``applyToChildren`` for AST input
    """
    if not isinstance(a, ast.AST):
        return a
    if type(a) == ast.Subscript and type(a.slice) == ast.Slice:
        bounds = a.slice
        # Lower defaults to 0
        if bounds.lower is not None and _is_number_literal_for_slice(bounds.lower, 0):
            bounds.lower = None
        # Upper defaults to len(value); require exactly one argument so that a
        # malformed len() call cannot raise IndexError here.
        upper = bounds.upper
        if (upper is not None and type(upper) == ast.Call
                and type(upper.func) == ast.Name and upper.func.id == "len"
                and len(upper.args) == 1
                and compareASTs(a.value, upper.args[0], checkEquality=True) == 0):
            bounds.upper = None
        # Step defaults to 1
        if bounds.step is not None and _is_number_literal_for_slice(bounds.step, 1):
            bounds.step = None
    return applyToChildren(a, cleanupSlices)
def combineConditionals(a):
    """When possible, combine conditional branches

    ``If`` nodes are rewritten in place after their body and else-branch have
    been processed recursively:

    * ``if a: (if b: x)`` where neither ``if`` has an else-branch becomes
      ``if a and b: x``;
    * ``if a: x elif b: x`` (the inner ``if`` has no else-branch and both
      bodies compare equal via ``compareASTs(..., checkEquality=True) == 0``)
      becomes ``if a or b: x``.

    The synthesized ``BoolOp`` is tagged ``combinedConditional=True`` and its
    operator ``combinedConditionalOp=True``. Other AST nodes are passed to
    ``applyToChildren(a, combineConditionals)``.

    NOTE(review): passing unknown keyword arguments to AST constructors (as
    done for the tags above) may emit DeprecationWarning on newer Python
    versions -- confirm the target interpreter.

    :param a: an AST node (non-AST values are returned unchanged)
    :return: the (possibly modified) node
    """
    if not isinstance(a, ast.AST):
        return a
    elif type(a) == ast.If:
        # Simplify nested statements first so the patterns below see combined children.
        for i in range(len(a.body)):
            a.body[i] = combineConditionals(a.body[i])
        for i in range(len(a.orelse)):
            a.orelse[i] = combineConditionals(a.orelse[i])
        # if a: if b: x can be - if a and b: x
        if (len(a.orelse) == 0) and (len(a.body) == 1) and \
           (type(a.body[0]) == ast.If) and (len(a.body[0].orelse) == 0):
            a.test = ast.BoolOp(ast.And(combinedConditionalOp=True), [a.test, a.body[0].test], combinedConditional=True)
            a.body = a.body[0].body
        # if a: x elif b: x can be - if a or b: x
        elif (len(a.orelse) == 1) and \
             (type(a.orelse[0]) == ast.If) and (len(a.orelse[0].orelse) == 0):
            if compareASTs(a.body, a.orelse[0].body, checkEquality=True) == 0:
                a.test = ast.BoolOp(ast.Or(combinedConditionalOp=True), [a.test, a.orelse[0].test], combinedConditional=True)
                a.orelse = []
        return a
    else:
        return applyToChildren(a, combineConditionals)
# Node types whose children may include statements. They are looked up by name
# so that classes missing from the running Python version are skipped instead
# of raising AttributeError. Computed once rather than on every recursive call.
_STATEMENT_HOLDER_TYPES = tuple(
    getattr(ast, _type_name)
    for _type_name in (
        "Module", "Interactive", "Suite",
        "FunctionDef", "AsyncFunctionDef", "ClassDef",
        "For", "AsyncFor", "While", "If", "With", "AsyncWith",
        "Try", "TryStar", "ExceptHandler", "Match", "match_case",
    )
    if hasattr(ast, _type_name)
)


def occursIn(sub, super):
    """Does the first AST occur as a subtree of the second?

    A match requires the same node type and
    ``compareASTs(sub, node, checkEquality=True) == 0``.

    :param sub: the subtree to look for
    :param super: the tree to search (non-AST values never contain anything)
    :return: True if *sub* occurs somewhere in *super*
    """
    if not isinstance(super, ast.AST):
        return False
    if type(sub) == type(super) and compareASTs(sub, super, checkEquality=True) == 0:
        return True
    # A statement can never occur in an expression (or in a statement that holds
    # no statements), so cut the search off now to save time.
    if isStatement(sub) and type(super) not in _STATEMENT_HOLDER_TYPES:
        return False
    return any(occursIn(sub, child) for child in ast.iter_child_nodes(super))
def gatherAllParameters(a, keep_orig=True):
    """Gather all parameters in the tree. Names are returned along
    with their original names (which are used in variable mapping)

    :param a: an AST node, or a list of nodes (e.g. a statement body);
        any other value yields an empty set
    :param keep_orig: when True, pair each name with the node's ``originalId``
        attribute if it has one; otherwise pair it with None
    :return: a set of ``(name, original_name_or_None)`` tuples, one per
        ``ast.arg`` node (function and lambda parameters)
    """
    if type(a) == list:
        # Recurse with this same function and forward keep_orig for every element.
        all_ids = set()
        for line in a:
            all_ids |= gatherAllParameters(line, keep_orig)
        return all_ids
    if not isinstance(a, ast.AST):
        return set()
    return {
        (node.arg, node.originalId if (keep_orig and hasattr(node, "originalId")) else None)
        for node in ast.walk(a)
        if type(node) == ast.arg
    }
def getAllGlobalNames(a):
    """Find all names that can be accessed at the global level in the AST.

    Only the top-level statements of an ``ast.Module`` are inspected:
    (async) function and class definitions, plain and augmented assignments
    (including names unpacked from tuple/list targets), and imports (the
    ``as`` name when given, otherwise the imported name).

    NOTE(review): for ``import a.b`` the full dotted name ``"a.b"`` is
    reported, although Python binds only ``a`` -- confirm callers expect this.

    :param a: the AST to inspect
    :return: list of names in statement order; [] if *a* is not a Module
    """
    if type(a) != ast.Module:
        return []
    names = []
    for obj in a.body:
        if type(obj) in (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef):
            names.append(obj.name)
        elif type(obj) in (ast.Assign, ast.AugAssign):
            # Assign has a list of targets; AugAssign has a single target.
            targets = obj.targets if type(obj) == ast.Assign else [obj.target]
            for target in targets:
                if type(target) == ast.Name:
                    names.append(target.id)
                elif type(target) in (ast.Tuple, ast.List):
                    names.extend(elt.id for elt in target.elts if type(elt) == ast.Name)
        elif type(obj) in (ast.Import, ast.ImportFrom):
            for alias in obj.names:
                names.append(alias.asname if alias.asname is not None else alias.name)
    return names
# Maps each AST comparison-operator class to the Python comparison it denotes.
_COMPARISON_HANDLERS = {
    ast.Eq: lambda lhs, rhs: lhs == rhs,
    ast.NotEq: lambda lhs, rhs: lhs != rhs,
    ast.Lt: lambda lhs, rhs: lhs < rhs,
    ast.LtE: lambda lhs, rhs: lhs <= rhs,
    ast.Gt: lambda lhs, rhs: lhs > rhs,
    ast.GtE: lambda lhs, rhs: lhs >= rhs,
    ast.Is: lambda lhs, rhs: lhs is rhs,
    ast.IsNot: lambda lhs, rhs: lhs is not rhs,
    ast.In: lambda lhs, rhs: lhs in rhs,
    ast.NotIn: lambda lhs, rhs: lhs not in rhs,
}


def doCompare(op, left, right):
    """Apply the AST comparison operator *op* to *left* and *right*.

    Dispatch is on the exact type of *op*; an operator that is not a
    comparison yields None.
    """
    handler = _COMPARISON_HANDLERS.get(type(op))
    if handler is None:
        return None
    return handler(left, right)
def test_AST_objects(self):
    """Check attribute and constructor behaviour of a bare ``ast.AST`` instance.

    Expects an empty ``_fields`` tuple and AttributeError when an unknown
    attribute is read, when one is assigned, and when a keyword is passed to
    the constructor. Also expects TypeError for a positional constructor
    argument. Returns early when ``support.check_impl_detail()`` is false
    (presumably the CPython-only guard from ``test.support``).

    NOTE(review): these expectations match older CPython releases; on versions
    where ``ast.AST`` instances carry a ``__dict__``, the ``x.foobar = 21`` and
    ``ast.AST(lineno=2)`` cases may not raise -- confirm the target interpreter.
    """
    if not support.check_impl_detail():
        # PyPy also provides a __dict__ to the ast.AST base class.
        return
    x = ast.AST()
    self.assertEqual(x._fields, ())
    with self.assertRaises(AttributeError):
        x.vararg
    with self.assertRaises(AttributeError):
        x.foobar = 21
    with self.assertRaises(AttributeError):
        ast.AST(lineno=2)
    with self.assertRaises(TypeError):
        # "_ast.AST constructor takes 0 positional arguments"
        ast.AST(2)
def test_AST_objects(self):
    """Check attribute and constructor behaviour of a bare ``ast.AST`` instance.

    Expects an empty ``_fields`` tuple and AttributeError when an unknown
    attribute is read, when one is assigned, and when a keyword is passed to
    the constructor. Also expects TypeError for a positional constructor
    argument.

    NOTE(review): these expectations match older CPython releases; on versions
    where ``ast.AST`` instances carry a ``__dict__``, the ``x.foobar = 21`` and
    ``ast.AST(lineno=2)`` cases may not raise -- confirm the target interpreter.
    """
    x = ast.AST()
    self.assertEqual(x._fields, ())
    with self.assertRaises(AttributeError):
        x.vararg
    with self.assertRaises(AttributeError):
        x.foobar = 21
    with self.assertRaises(AttributeError):
        ast.AST(lineno=2)
    with self.assertRaises(TypeError):
        # "_ast.AST constructor takes 0 positional arguments"
        ast.AST(2)
def test_AST_objects(self):
    """Check attribute and constructor behaviour of a bare ``ast.AST`` instance.

    Expects an empty ``_fields`` tuple and AttributeError when an unknown
    attribute is read, when one is assigned, and when a keyword is passed to
    the constructor. Also expects TypeError for a positional constructor
    argument.

    NOTE(review): these expectations match older CPython releases; on versions
    where ``ast.AST`` instances carry a ``__dict__``, the ``x.foobar = 21`` and
    ``ast.AST(lineno=2)`` cases may not raise -- confirm the target interpreter.
    """
    x = ast.AST()
    self.assertEqual(x._fields, ())
    with self.assertRaises(AttributeError):
        x.vararg
    with self.assertRaises(AttributeError):
        x.foobar = 21
    with self.assertRaises(AttributeError):
        ast.AST(lineno=2)
    with self.assertRaises(TypeError):
        # "_ast.AST constructor takes 0 positional arguments"
        ast.AST(2)
def visit_FunctionDef(self, node):
    """Record a function's parameter offsets in ``self.funcs``, then keep visiting.

    Positional parameters, ``*args``/``**kwargs`` (when they are AST nodes) and
    keyword-only parameters are collected. ``has_starargs`` is True when the
    function declares any of ``*args``, ``**kwargs`` or keyword-only
    parameters. An entry ``Func(node, has_starargs, offsets)`` is stored under
    the function's ``Offset(lineno, col_offset)`` only if at least one
    parameter offset was found.
    """
    signature = node.args
    params = list(signature.args)
    has_starargs = False
    for star_param in (signature.vararg, signature.kwarg):
        if not star_param:
            continue
        if isinstance(star_param, ast.AST):  # pragma: no cover (py3)
            params.append(star_param)
        has_starargs = True
    kwonly = getattr(signature, 'kwonlyargs', None)
    if kwonly:  # pragma: no cover (py3)
        params.extend(kwonly)
        has_starargs = True
    offsets = {_to_offset(param) for param in params}
    if offsets:
        func_key = Offset(node.lineno, node.col_offset)
        self.funcs[func_key] = Func(node, has_starargs, offsets)
    self.generic_visit(node)
def visit_FunctionDef(self, node, phase):
    """Visitor for AST FunctionDef nodes.

    Stores the function node, its qualified name (current namespace plus the
    function name) and its short name in ``self.context``. It then extends
    ``self.namespace`` with the function name, so that descendants and the
    tests run here see it. Finally it runs the 'FunctionDef' tests for *phase*
    and feeds the results to ``self.update_scores``.

    :param node: The node that is being inspected
    :param phase: the test phase forwarded to ``self.tester.run_tests``
    :return: -
    """
    self.context['function'] = node
    qualname = self.namespace + '.' + b_utils.get_func_name(node)
    short_name = qualname.rsplit('.', 1)[-1]
    self.context['qualname'] = qualname
    self.context['name'] = short_name
    # Descendants (and the tests below) see this function in the namespace.
    self.namespace = b_utils.namespace_path_join(self.namespace, short_name)
    results = self.tester.run_tests(self.context, 'FunctionDef', phase=phase)
    self.update_scores(results)
def visit_Call(self, node, phase):
    """Visitor for AST Call nodes.

    Stores the call node, the called name resolved through
    ``self.import_aliases`` and its last dotted component in ``self.context``.
    It then runs the 'Call' tests for *phase* and feeds the results to
    ``self.update_scores``.

    :param node: The node that is being inspected
    :param phase: the test phase forwarded to ``self.tester.run_tests``
    :return: -
    """
    self.context['call'] = node
    full_name = b_utils.get_call_name(node, self.import_aliases)
    short_name = full_name.rsplit('.', 1)[-1]
    self.context['qualname'] = full_name
    self.context['name'] = short_name
    results = self.tester.run_tests(self.context, 'Call', phase=phase)
    self.update_scores(results)
def generic_visit(self, node, phase=None):
    """Drive the visitor.

    For every AST child of *node* (single-node fields and list elements, in
    field order): skip it if ``self.pre_visit`` returns a falsy value;
    otherwise call ``self.visit``, recurse with ``self.generic_visit``, and
    finish with ``self.post_visit``. *phase* defaults to ``constants.PRIMARY``.
    """
    phase = phase or constants.PRIMARY
    for _, field_value in ast.iter_fields(node):
        children = field_value if isinstance(field_value, list) else [field_value]
        for child in children:
            if not isinstance(child, ast.AST) or not self.pre_visit(child):
                continue
            self.visit(child, phase=phase)
            self.generic_visit(child, phase=phase)
            self.post_visit(child)
def check_call_visitor(self, visitor):
    """Assert that *visitor* reports errors with the location and the visited node.

    The node is shown via ``ast.dump``. With
    ``fatoptimizer.tools.COMPACT_DUMP_MAXLEN`` patched to 5, the dump is
    expected to be truncated to ``BinOp(...)``.
    """
    message_template = 'error at <string>:1 on visiting %s: bug'
    tree = ast.parse("1+1")

    with self.assertRaises(Exception) as cm:
        visitor.visit(tree)
    full_dump = ast.dump(tree.body[0].value)
    self.assertEqual(str(cm.exception), message_template % full_dump)

    # Test truncature of the AST dump
    with mock.patch('fatoptimizer.tools.COMPACT_DUMP_MAXLEN', 5), \
            self.assertRaises(Exception) as cm:
        visitor.visit(tree)
    self.assertEqual(str(cm.exception), message_template % 'BinOp(...)')
def generic_visit(self, node: ast.AST) -> None:
    """Called if no explicit visitor function exists for a node.

    Visits every AST child of *node* in field order. Stops as soon as
    ``self.should_type_check`` becomes true, which is checked before each
    field and before each child.
    """
    for _field, value in ast.iter_fields(node):
        if self.should_type_check:
            return
        children = value if isinstance(value, list) else [value]
        for child in children:
            if self.should_type_check:
                return
            if isinstance(child, ast.AST):
                self.visit(child)
# Generic mypy error
def visit_For(self, node):
    """Visit a ``for`` loop, marking its target as the left-hand side.

    ``self.flag`` is set to ``'lhs'`` while the ``target`` field is visited.
    The value the flag had before each field is restored after that field,
    so the other fields are visited with the incoming flag.
    """
    for field_name, field_value in ast.iter_fields(node):
        saved_flag = self.flag
        if field_name == 'target':
            self.flag = 'lhs'
        children = field_value if isinstance(field_value, list) else [field_value]
        for child in children:
            if isinstance(child, ast.AST):
                self.visit(child)
        self.flag = saved_flag
def visit_Assign(self, node):
    """Visit an assignment, flagging targets as ``'lhs'`` and the value as ``'rhs'``.

    The flag set for a field stays in effect for the following fields until
    another tagged field is reached. ``self.flag`` is reset to None at the end.
    """
    side_for_field = {'targets': 'lhs', 'value': 'rhs'}
    for field_name, field_value in ast.iter_fields(node):
        if field_name in side_for_field:
            self.flag = side_for_field[field_name]
        children = field_value if isinstance(field_value, list) else [field_value]
        for child in children:
            if isinstance(child, ast.AST):
                self.visit(child)
    self.flag = None
def visit_AugAssign(self, node):
    """Visit an augmented assignment, flagging the target ``'lhs'`` and the value ``'rhs'``.

    The operator field, which lies between them, is visited while the flag
    is still ``'lhs'``. ``self.flag`` is reset to None at the end.
    """
    side_for_field = {'target': 'lhs', 'value': 'rhs'}
    for field_name, field_value in ast.iter_fields(node):
        if field_name in side_for_field:
            self.flag = side_for_field[field_name]
        children = field_value if isinstance(field_value, list) else [field_value]
        for child in children:
            if isinstance(child, ast.AST):
                self.visit(child)
    self.flag = None
def to_source_any(n):
    """
    Convert AST node to string, handling all node types, without fixing comments.

    ``astor.to_source`` is tried first. If it raises AttributeError, the
    node's class is looked up in ``astor.misc.all_symbols``. Failing that,
    expression-context nodes and ``ast.keyword`` map to a ``___name___``
    placeholder.

    :raises AttributeError: if the node type is not handled by any of the above
    """
    try:
        return astor.to_source(n)
    except AttributeError:
        pass  # astor cannot render this node directly; try the symbol tables
    node_class = n.__class__
    if node_class in astor.misc.all_symbols:
        return astor.misc.all_symbols[node_class]
    placeholder_names = {
        ast.Load: 'load',
        ast.Store: 'store',
        ast.Del: 'del',
        ast.AugLoad: 'augload',
        ast.AugStore: 'augstore',
        ast.Param: 'param',
        ast.keyword: 'keyword',
    }
    if node_class not in placeholder_names:
        raise AttributeError('unknown node type {}'.format(node_class))
    return '___' + placeholder_names[node_class] + '___'
def add_after_node(ast_root, after_node, node_to_add):
    """Same idea as add_before_node, but in this case add it after after_node

    *after_node* is located in *ast_root* with ``find_node_recursive``. Then
    *node_to_add* is inserted into the parent's list field, immediately after
    the first element for which ``nodes_are_equal`` holds. The tree is
    modified in place.

    :raises ValueError: if *after_node* is not found in *ast_root*; the
        message includes a dump of the searched tree
    """
    node, parent = find_node_recursive(ast_root, after_node)
    if node is None:
        raise ValueError("Node %s not found in ast: %s" % (
            str(after_node),
            dump_ast(ast_root)))
    for field, value in ast.iter_fields(parent):
        if not isinstance(value, list):
            continue
        for i, item in enumerate(value):
            if isinstance(item, ast.AST) and nodes_are_equal(item, node):
                value.insert(i + 1, node_to_add)
                return
    # NOTE(review): reaching this point means the node was found but not in a
    # list field of its parent, so nothing was inserted -- confirm this
    # silent no-op is intended.
def get_all_nodes_in_bfs_order(ast_root):
    """Return every AST node reachable from *ast_root*, in breadth-first order.

    *ast_root* comes first, followed by its children level by level. Children
    are taken in field order, including single-node fields such as an
    expression context.

    :param ast_root: the root ``ast.AST`` node
    :return: list of nodes in BFS order
    """
    from collections import deque

    queue = deque([ast_root])
    result = []
    while queue:
        node = queue.popleft()
        result.append(node)
        # Enqueue every child, whether it comes from a list field or a
        # single-node field, so that each subtree is expanded in turn.
        queue.extend(ast.iter_child_nodes(node))
    return result