def find_parent_field(parent, node):
    """
    Return the name of the field of ``parent`` through which ``node`` is reached.

    ``parent`` is assumed to be an ancestor of ``node`` and every node on the
    path between them must carry a ``parent`` attribute. The function climbs
    from ``node`` until it reaches the direct child of ``parent``, then looks
    that child up among the fields of ``parent`` (either as the field value
    itself or as a member of a list-valued field).

    Returns ``None`` when no field of ``parent`` holds that child.
    """
    child = node
    # Climb until ``child`` is the direct child of ``parent``.
    while child.parent != parent:
        child = child.parent
    for name, content in ast.iter_fields(parent):
        if isinstance(content, list):
            found = child in content
        else:
            found = child == content
        if found:
            return name
    return None
# Example source code using Python's ast.iter_fields() (original heading: python类iter_fields()的实例源码)
def visit_Subscript(self, node):
    """
    Tag variables that are followed by an index as arrays, then visit the children.

    When the subscripted value has an ``id`` attribute, that name is appended
    to ``self.array_name`` unless it is already present. Every field of
    ``node`` is then visited through ``self.visit``. While the ``slice`` and
    ``ctx`` fields are being visited ``self.flag`` is set to ``None``; the
    previous value is restored after each field, even if a visit raises.
    """
    try:
        name = node.value.id
        if name not in self.array_name:
            self.array_name.append(name)
    except AttributeError:
        # The subscripted value is not a plain name (e.g. ``f()[0]``), so
        # there is nothing to tag.
        pass
    for field, value in ast.iter_fields(node):
        flag_cache = self.flag
        if field in ('slice', 'ctx'):
            self.flag = None
        try:
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        self.visit(item)
            elif isinstance(value, ast.AST):
                self.visit(value)
        finally:
            # Restore the flag for the next field regardless of how the
            # visit of this one ended.
            self.flag = flag_cache
def visit_Subscript(self, node):
    """
    Record and visit every index expression of a subscript.

    The fields of ``node.slice`` are inspected: a ``value`` field contributes
    either the elements of a tuple or the single expression, a ``dims`` field
    contributes each of its entries. For every such expression its
    ``py_ast.dump_ast`` rendering is appended to ``self.value``, a fresh empty
    list is appended to ``self.write``, and the expression is visited.

    Processing is best effort: an ordinary exception raised while handling
    one field abandons the rest of that field and is not propagated.

    NOTE(review): on Python 3.9+ ``node.slice`` is the index expression
    itself rather than an ``ast.Index``/``ast.ExtSlice`` wrapper, so the
    ``value``/``dims`` fields matched here may not be the intended ones --
    confirm the targeted Python version.
    """
    for field, value in ast.iter_fields(node.slice):
        if field not in ('value', 'dims'):
            continue
        try:
            if field == 'dims':
                targets = value
            elif isinstance(value, ast.Tuple):
                targets = value.elts
            else:
                targets = [value]
            for target in targets:
                ind_value = py_ast.dump_ast(target)
                self.value.append(ind_value)
                self.write.append([])
                self.visit(target)
        except Exception:
            # Deliberate best effort, but no longer swallows
            # KeyboardInterrupt / SystemExit as the bare ``except`` did.
            continue
def replace_node(ast_root, node_to_replace, replacement_node):
    """Replace node_to_replace with replacement_node in the ast.

    The node to replace must expose a ``parent`` attribute; when it does not,
    ``add_parent_info(ast_root)`` is called once in an attempt to provide it.
    The fields of that parent are then searched (list members and direct
    values) for a node that ``nodes_are_equal`` reports equal to
    ``node_to_replace``; the first match is overwritten with
    ``replacement_node`` and ``replacement_node.parent`` is set to the parent.

    If no field of the parent matches, the function returns without changes.

    Raises:
        ValueError: if ``node_to_replace`` still has no ``parent`` attribute
            after ``add_parent_info`` has run.
    """
    if not hasattr(node_to_replace, 'parent'):
        add_parent_info(ast_root)
    # if the node still has no parent it is not part of this ast
    if not hasattr(node_to_replace, 'parent'):
        raise ValueError("Node %s not found in ast: %s" % (
            str(node_to_replace),
            dump_ast(node_to_replace)))
    parent = node_to_replace.parent
    # find the node within its parent and replace it
    for field, value in ast.iter_fields(parent):
        if isinstance(value, list):
            for index, candidate in enumerate(value):
                if isinstance(candidate, ast.AST) and \
                        nodes_are_equal(candidate, node_to_replace):
                    value[index] = replacement_node
                    # Keep parent links consistent with the non-list branch;
                    # previously the replacement was left without a parent.
                    setattr(replacement_node, 'parent', parent)
                    return
        elif isinstance(value, ast.AST) and nodes_are_equal(value, node_to_replace):
            setattr(parent, field, replacement_node)
            setattr(replacement_node, 'parent', parent)
            return
def add_before_node(ast_root, before_node, node_to_add):
    """Attempt to insert node_to_add immediately before before_node.

    For example, given the code:

        def foo(j):
            for i in range(j):
                print(i)

    with before_node being "for i in range(j):" and node_to_add being
    "print(2)", the result is:

        def foo(j):
            print(2)
            for i in range(j):
                print(i)

    The node is located with ``find_node_recursive``; insertion only happens
    when the located node is a member of a list-valued field of its parent.
    Otherwise the tree is left untouched.

    Raises:
        ValueError: if ``find_node_recursive`` does not find before_node.
    """
    node, parent = find_node_recursive(ast_root, before_node)
    if node is None:
        raise ValueError("Node %s not found in ast: %s" % (
            str(before_node),
            dump_ast(before_node)))
    for _field, value in ast.iter_fields(parent):
        if not isinstance(value, list):
            continue
        for position, candidate in enumerate(value):
            if isinstance(candidate, ast.AST) and nodes_are_equal(candidate, node):
                value.insert(position, node_to_add)
                return
def _LiteralEval(value):
"""Parse value as a Python literal, or container of containers and literals.
First the AST of the value is updated so that bare-words are turned into
strings. Then the resulting AST is evaluated as a literal or container of
only containers and literals.
This allows for the YAML-like syntax {a: b} to represent the dict {'a': 'b'}
Args:
value: A string to be parsed as a literal or container of containers and
literals.
Returns:
The Python value representing the value arg.
Raises:
ValueError: If the value is not an expression with only containers and
literals.
SyntaxError: If the value string has a syntax error.
"""
root = ast.parse(value, mode='eval')
if isinstance(root.body, ast.BinOp):
raise ValueError(value)
for node in ast.walk(root):
for field, child in ast.iter_fields(node):
if isinstance(child, list):
for index, subchild in enumerate(child):
if isinstance(subchild, ast.Name):
child[index] = _Replacement(subchild)
elif isinstance(child, ast.Name):
replacement = _Replacement(child)
node.__setattr__(field, replacement)
# ast.literal_eval supports the following types:
# strings, bytes, numbers, tuples, lists, dicts, sets, booleans, and None
# (bytes and set literals only starting with Python 3.2)
return ast.literal_eval(root)
def generic_visit(self, node):
    """Visit every child of ``node`` and apply the returned replacements.

    List-valued fields are rewritten in place with whatever
    ``self.visit_list`` returns. A field holding a single AST node is set to
    the result of ``self.visit``, or deleted from ``node`` when that result
    is ``None``. Returns ``node`` itself.
    """
    for name, _ in ast.iter_fields(node):
        current = getattr(node, name, None)
        if isinstance(current, list):
            current[:] = self.visit_list(current)
            continue
        if not isinstance(current, ast.AST):
            continue
        replacement = self.visit(current)
        if replacement is not None:
            setattr(node, name, replacement)
        else:
            delattr(node, name)
    return node
def generic_visit(self, node):
    """Fallback used when no explicit visitor function exists for a node.

    Single AST children are passed to ``self.visit``; list-valued fields are
    passed as a whole to ``self.visit_list``. Other field values are ignored.
    """
    for _name, child in ast.iter_fields(node):
        if isinstance(child, ast.AST):
            self.visit(child)
        elif isinstance(child, list):
            self.visit_list(child)
def generic_visit(self, node):
"""Called if no explicit visitor function exists for a node.

Visits values found in the fields of ``node`` via ``self.visit``, collects
the results, and returns the truthy ones as a list. Prints a debug line when
nothing at all was collected.

NOTE(review): this block's indentation has been lost, so the nesting of the
final ``res.append(self.visit(value))`` is ambiguous (per field, per dict, or
per dict entry) -- confirm against the original before re-indenting.
"""
# Accumulates the result of every ``self.visit`` call made below.
res = []
for field, value in ast.iter_fields(node):
if isinstance(value, list):
for item in value:
vv = self.visit(item)
res.append(vv)
elif isinstance(value, dict):
# NOTE(review): iterating a dict directly yields its keys, so the
# ``k, v`` unpacking below only works for 2-element keys; this
# presumably was meant to be ``value.items()`` -- verify.
for k, v in value:
res.append(self.visit(v))
res.append(self.visit(value))
# Debug aid: report nodes for which no visit produced any entry.
if not res:
print("visiting", node)
# Drop falsy results (``None``, empty lists, ...) before returning.
return list(filter(None, res))
def test_iter_fields(self):
node = ast.parse('foo()', mode='eval')
d = dict(ast.iter_fields(node.body))
self.assertEqual(d.pop('func').id, 'foo')
self.assertEqual(d, {'keywords': [], 'kwargs': None,
'args': [], 'starargs': None})
def test_iter_fields(self):
node = ast.parse('foo()', mode='eval')
d = dict(ast.iter_fields(node.body))
self.assertEqual(d.pop('func').id, 'foo')
self.assertEqual(d, {'keywords': [], 'kwargs': None,
'args': [], 'starargs': None})
def test_iter_fields(self):
node = ast.parse('foo()', mode='eval')
d = dict(ast.iter_fields(node.body))
self.assertEqual(d.pop('func').id, 'foo')
self.assertEqual(d, {'keywords': [], 'kwargs': None,
'args': [], 'starargs': None})
def visit_Compare(self, node):
    """Translate a comparison with one or two operators into a criteria object.

    The left operand and every comparator are visited in turn; each visit is
    expected to leave its result on ``self.data``, from which it is popped.
    A single-operator comparison builds the criteria class looked up for that
    operator (a list/tuple right-hand side is spread into separate
    arguments). A two-operator comparison builds a ``Const.Between`` criteria
    from the three operands and the two looked-up operators. The resulting
    criteria is pushed onto ``self.data``.

    Raises:
        SyntaxError: if the comparison does not have exactly 1 or 2 operators.
    """
    op_count = len(node.ops)
    if op_count != 1 and op_count != 2:
        raise SyntaxError("ast.Compare with more than 2 ops: %s is not supported" % node)

    left_node, ops, comps = (value for _, value in ast.iter_fields(node))

    self.visit(left_node)
    left = self.data.pop()

    comparators = []
    for comparator in comps:
        self.visit(comparator)
        comparators.append(self.data.pop())

    if len(ops) != 1:
        lower_op = ast_op_to_operator.lookup(type(ops[0]))
        upper_op = ast_op_to_operator.lookup(type(ops[1]))
        criteria = criteria_class.instance(
            Const.Between, left, comparators[0], comparators[1], lower_op, upper_op)
    else:
        right = comparators[0]
        cls = criteria_class.lookup(ast_op_to_criteria.lookup(type(ops[0])))
        if type(right) in (list, tuple,):
            criteria = cls(left, *right)
        else:
            criteria = cls(left, right)
    self.data.append(criteria)
def visit_Call(self, node):
    """Evaluate a call expression through a registered deserializer.

    Only the truthy fields of ``node`` are considered. The callee is visited
    and its result (popped from ``self.data``) is used to look up a
    deserializer via ``SyntaxAstCallExtender.find_deserializer``. Positional
    arguments, keyword arguments and -- when the node has such a field -- the
    key/value pairs of ``Const.kwargs`` are each visited and popped from
    ``self.data``; the deserializer is then called with them and its result
    pushed onto ``self.data``.

    Raises:
        SyntaxError: if no deserializer is found for the callee.
    """
    fields = {}
    for field_name, field_value in ast.iter_fields(node):
        if field_value:
            fields[field_name] = field_value

    self.visit(fields[Const.func])
    name = self.data.pop()
    args = list()
    kwargs = collections.OrderedDict()

    func = SyntaxAstCallExtender.find_deserializer(name)
    if not func:
        raise SyntaxError("%s is not supported" % name)

    for arg in fields.get(Const.args, ()):
        self.visit(arg)
        args.append(self.data.pop())

    for keyword in fields.get(Const.keywords, ()):
        (_, key), (_, value) = ast.iter_fields(keyword)
        self.visit(value)
        kwargs[key] = self.data.pop()

    if Const.kwargs in fields:
        (_, knodes), (_, vnodes) = ast.iter_fields(fields[Const.kwargs])
        for knode, vnode in zip(knodes, vnodes):
            self.visit(knode)
            key = self.data.pop()
            self.visit(vnode)
            kwargs[key] = self.data.pop()

    self.data.append(func(*args, **kwargs))
def generic_visit(self, node, inner=False):
    """Visit every AST child of ``node``, forwarding ``inner`` unchanged.

    Children are taken from the fields of ``node``: a field that is a list
    contributes each AST member, a field that is a single AST node
    contributes itself. Anything else is skipped.
    """
    for _, value in ast.iter_fields(node):
        children = value if isinstance(value, list) else [value]
        for child in children:
            if isinstance(child, ast.AST):
                self.visit(child, inner)
def generic_visit(self, node):
    """Generator that yields ``node`` once per step of visiting its children.

    For each list-valued field the steps come from
    ``self.generic_visit_list``; for each field holding a single AST node
    they come from ``self.generic_visit_real_node``. Other fields produce no
    steps. The values of the steps themselves are discarded.
    """
    for name, current in ast.iter_fields(node):
        if isinstance(current, list):
            steps = self.generic_visit_list(current)
        elif isinstance(current, ast.AST):
            steps = self.generic_visit_real_node(node, name, current)
        else:
            continue
        for _step in steps:
            yield node
def test_iter_fields(self):
node = ast.parse('foo()', mode='eval')
d = dict(ast.iter_fields(node.body))
self.assertEqual(d.pop('func').id, 'foo')
self.assertEqual(d, {'keywords': [], 'kwargs': None,
'args': [], 'starargs': None})
def str_node(node):
    """Render ``node`` as ``ClassName(field=value, ...)``.

    AST nodes are rendered recursively through their fields, leaving out any
    field named ``left`` or ``right``. Every other value is rendered with
    ``repr``.
    """
    if not isinstance(node, ast.AST):
        return repr(node)
    rendered = []
    for name, val in ast.iter_fields(node):
        if name in ('left', 'right'):
            continue
        rendered.append('%s=%s' % (name, str_node(val)))
    head = '%s(%s' % (node.__class__.__name__, ', '.join(rendered))
    return head + ')'
def ast_visit(node, level=0):
    """Print ``node`` and, recursively, every AST node below it.

    Each node is printed on its own line as rendered by ``str_node``,
    prefixed with one space per nesting ``level``.
    """
    print(' ' * level + str_node(node))
    deeper = level + 1
    for _field, value in ast.iter_fields(node):
        candidates = value if isinstance(value, list) else (value,)
        for candidate in candidates:
            if isinstance(candidate, ast.AST):
                ast_visit(candidate, level=deeper)
def filter_arguments(arguments, bound_argnames):
    """
    Build a copy of an ``ast.arguments`` node without the arguments whose
    names appear in ``bound_argnames``.

    Positional arguments (with their defaults) and keyword-only arguments
    (with their defaults) are filtered through ``filter_arglist``; ``*args``
    and ``**kwargs`` entries are dropped when their name is bound. All other
    fields are carried over unchanged.

    Returns the new ``ast.arguments`` node.

    NOTE(review): fields not listed above (e.g. ``posonlyargs`` on
    Python 3.8+) are copied as-is and therefore not filtered -- confirm
    whether that is intended.
    """
    assert type(arguments) is ast.arguments

    new_params = dict(ast.iter_fields(arguments))

    positional, positional_defaults = filter_arglist(
        arguments.args, arguments.defaults, bound_argnames)
    new_params['args'] = positional
    new_params['defaults'] = positional_defaults

    keyword_only, keyword_only_defaults = filter_arglist(
        arguments.kwonlyargs, arguments.kw_defaults, bound_argnames)
    new_params['kwonlyargs'] = keyword_only
    new_params['kw_defaults'] = keyword_only_defaults

    for star_field in ('vararg', 'kwarg'):
        star = getattr(arguments, star_field)
        star_name = None if star is None else star.arg
        if star_name is not None and star_name in bound_argnames:
            new_params[star_field] = None

    return ast.arguments(**new_params)