def generic_visit(self, node, phase=None):
    """Drive the visitor over the direct children of *node*.

    Children are the AST values of single-node fields and the AST elements
    of list fields, in field order. Each child is offered to
    ``self.pre_visit``; when that returns a falsy value the child is skipped
    entirely. Otherwise the child is visited (``self.visit``), recursed into
    (``self.generic_visit``) and finished (``self.post_visit``).

    A falsy *phase* is replaced by ``constants.PRIMARY``.
    """
    if not phase:
        phase = constants.PRIMARY
    for child in ast.iter_child_nodes(node):
        if not self.pre_visit(child):
            continue
        self.visit(child, phase=phase)
        self.generic_visit(child, phase=phase)
        self.post_visit(child)
# Example source code using Python's ast.iter_fields()
def generic_visit(self, node: ast.AST) -> None:
    """Fallback visitor used when no explicit ``visit_<Class>`` method exists.

    Visits the direct child nodes of *node* in field order, but stops as soon
    as ``self.should_type_check`` is truthy (checked before every child), so
    no further children are visited once a visit has set it.
    """
    for child in ast.iter_child_nodes(node):
        if self.should_type_check:
            return
        self.visit(child)
# Generic mypy error
def is_before(node1, node2):
    """
    Check whether node1 definitely appears before node2.

    Climbs from *node2* through its ``parent`` links. At every ancestor
    (starting with *node2* itself) whose parent is *node1*'s parent, the
    fields of that shared parent are scanned in order:

    * if the ancestor or *node1* is the value of a single-node field,
      no list ordering can be compared and False is returned;
    * if both sit in the same list field and the ancestor comes later in
      that list than *node1*, True is returned.

    Nodes are expected to carry a ``parent`` attribute (None at the root).
    Ancestors where a ``parent`` link is missing, or where the shared parent
    is None, are skipped. Returns False when no ordering is established.
    """
    current = node2
    while current is not None:
        try:
            if current.parent == node1.parent:
                for _field, value in ast.iter_fields(node1.parent):
                    if value == current or value == node1:
                        return False
                    if isinstance(value, list) and current in value and node1 in value:
                        if value.index(current) > value.index(node1):
                            return True
        except AttributeError:
            # A missing ``parent`` link, or ``ast.iter_fields(None)`` when both
            # sit at the root: this ancestor tells us nothing, keep climbing.
            # (Previously a bare ``except`` that also swallowed
            # KeyboardInterrupt and real bugs.)
            pass
        current = current.parent
    return False
def insert_before_parent_list_fixed(node_list, s):
    """
    Insert the statements parsed from *s* before the first usable node.

    Nodes in *node_list* are tried in order. For each one, the list-valued
    fields of its ``parent`` are searched for the node itself. At the first
    hit, the ``body`` statements of ``ast.parse(s)`` are spliced in directly
    before it, and the function returns. Nodes that are not an element of any
    list field of their parent are skipped.

    Args:
        node_list: candidate AST nodes; each must have a ``parent`` attribute.
        s: Python source whose parsed statements are inserted.

    Raises:
        SyntaxError: if *s* is not valid Python (once an insertion point has
            been found). The old bare ``except`` silently ignored this.
    """
    for target in node_list:
        for _field, value in ast.iter_fields(target.parent):
            if not isinstance(value, list):
                continue
            try:
                index = value.index(target)
            except ValueError:
                # Not in this list field; try the parent's next field.
                continue
            value[index:index] = ast.parse(s).body
            return
def visit_For(self, node):
    """Visit a ``for`` loop, marking the loop target as written (``'lhs'``).

    Every field of *node* is walked in order. While the ``target`` field is
    walked, ``self.flag`` is ``'lhs'``; other fields see the flag unchanged.
    After each field, ``self.flag`` is put back to the value it had before
    that field.
    """
    for field_name, field_value in ast.iter_fields(node):
        saved_flag = self.flag
        if field_name == 'target':
            self.flag = 'lhs'
        children = field_value if isinstance(field_value, list) else [field_value]
        for child in children:
            if isinstance(child, ast.AST):
                self.visit(child)
        self.flag = saved_flag
def visit_Assign(self, node):
    """Visit an assignment, tagging targets ``'lhs'`` and the value ``'rhs'``.

    ``self.flag`` is ``'lhs'`` while the ``targets`` field is walked and
    ``'rhs'`` while the ``value`` field is walked. When the assignment has
    been walked, ``self.flag`` is reset to None.
    """
    role_for_field = {'targets': 'lhs', 'value': 'rhs'}
    for field_name, field_value in ast.iter_fields(node):
        if field_name in role_for_field:
            self.flag = role_for_field[field_name]
        if isinstance(field_value, ast.AST):
            self.visit(field_value)
        elif isinstance(field_value, list):
            for child in field_value:
                if isinstance(child, ast.AST):
                    self.visit(child)
    self.flag = None
def add_after_node(ast_root, after_node, node_to_add):
    """Same idea as add_before_node, but in this case add it after after_node.

    ``find_node_recursive`` locates *after_node* in *ast_root* together with
    its parent. The parent's list-valued fields are then scanned, and
    *node_to_add* is inserted right after the first element for which
    ``nodes_are_equal`` reports a match. If no list element matches, nothing
    is inserted.

    Raises:
        ValueError: if ``find_node_recursive`` returns no node.
    """
    node, parent = find_node_recursive(ast_root, after_node)
    if node is None:
        raise ValueError("Node %s not found in ast: %s" % (
            str(after_node),
            dump_ast(after_node)))
    for _, field_value in ast.iter_fields(parent):
        if not isinstance(field_value, list):
            continue
        for position, candidate in enumerate(field_value):
            if isinstance(candidate, ast.AST) and nodes_are_equal(candidate, node):
                field_value.insert(position + 1, node_to_add)
                return
def get_all_nodes_in_bfs_order(ast_root):
    """Return every node reachable from *ast_root* in breadth-first order.

    The root comes first, then its direct children in field order, then their
    children, and so on; this is the same order ``ast.walk`` yields. Both
    list-valued fields and single-node fields are descended into.

    Returns:
        list of ``ast.AST`` nodes, root first.
    """
    # The result list doubles as the BFS queue: ``cursor`` points at the next
    # node whose children still have to be enqueued. This avoids the O(n)
    # ``list.pop(0)`` of a plain-list queue.
    result = [ast_root]
    cursor = 0
    while cursor < len(result):
        current = result[cursor]
        cursor += 1
        for _field, value in ast.iter_fields(current):
            if isinstance(value, list):
                result.extend(item for item in value if isinstance(item, ast.AST))
            elif isinstance(value, ast.AST):
                # Single-node fields must be queued too; otherwise their
                # subtrees are never explored.
                result.append(value)
    return result
def _walk_ast(self, node, top=False):
if not hasattr(node, 'parent'):
node.parent = None
node.parents = []
for field, value in ast.iter_fields(node):
if isinstance(value, list):
for index, item in enumerate(value):
if isinstance(item, ast.AST):
self._walk_ast(item)
self._set_parnt_fields(item, node, field, index)
elif isinstance(value, ast.AST):
self._walk_ast(value)
self._set_parnt_fields(value, node, field)
if top:
return ast.walk(node)
def run_in_context(code, context, defs=None):
    """Execute *code* and return its final expression value or definition.

    The source is parsed as a module.

    * If the last top-level statement is an expression, it is split off, the
      rest is executed, and the expression's value is returned.
    * If the last statement is a ``def`` or ``class``, everything is executed
      and the newly bound object is returned from *defs*.
    * Otherwise the code is executed and None is returned.

    Args:
        code: Python source text. It is run with ``exec``/``eval``, so it
            must come from a trusted source.
        context: globals mapping used for execution.
        defs: locals mapping receiving top-level bindings. Defaults to a
            fresh dict per call; the former ``{}`` default was a single
            shared dict, so bindings leaked from one call into the next.
    """
    if defs is None:
        defs = {}
    module = ast.parse(code, '<code>', 'exec')
    last_expr = None
    last_def_name = None
    if module.body:
        last_stmt = module.body[-1]
        if isinstance(last_stmt, ast.Expr):
            # Evaluate the trailing expression separately so its value can be
            # returned; remove it from the statements passed to exec.
            last_expr = ast.Expression(body=module.body.pop().value)
        elif isinstance(last_stmt, (ast.FunctionDef, ast.ClassDef)):
            last_def_name = last_stmt.name
    exec(compile(module, '<hbi-code>', 'exec'), context, defs)
    if last_expr is not None:
        return eval(compile(last_expr, '<hbi-code>', 'eval'), context, defs)
    if last_def_name is not None:
        return defs[last_def_name]
    return None
def extract_option_from_arg_list(options, optname, default_value):
    """Pull keyword option *optname* out of a comma-separated argument string.

    *options* is parsed as the argument list of a call, ``f(<options>)``.
    If it contains ``optname=<expr>``, the expression is evaluated, and the
    value is returned along with *options* minus that option. Otherwise
    *default_value* and the unchanged *options* are returned. An empty or
    None *options* short-circuits to ``(default_value, options)``.

    Returns:
        ``(value, remaining_options)`` tuple.

    Raises:
        ValueError: if *options* is not a valid argument list, or the
            option's value cannot be evaluated.
    """
    if not options:
        return default_value, options
    try:
        call = ast.parse(f"f({options})", mode='eval').body
    except SyntaxError as e:
        raise ValueError(f"Expect a list of keyword arguments: {options} provided") from e
    for keyword in call.keywords:
        if keyword.arg != optname:
            continue
        try:
            # NOTE(review): ``eval`` runs arbitrary code contained in
            # *options*. If options can come from untrusted input, consider
            # ast.literal_eval, which matches the "constant value" contract.
            value = eval(compile(ast.Expression(body=keyword.value), filename="<ast>", mode="eval"))
        except Exception as e:
            raise ValueError(f"A constant value is expected for option {optname}: {options} provided.") from e
        # Drop only segments whose keyword name is exactly *optname*; the old
        # prefix test also removed e.g. ``abc=...`` when optname was ``a``.
        # NOTE(review): splitting on ',' still mis-handles option values that
        # themselves contain commas (e.g. lists).
        kept = [seg for seg in options.split(',') if seg.split('=', 1)[0].strip() != optname]
        return value, ','.join(kept).strip()
    return default_value, options
def _walk_fields(self, state, node, ctx):
    """
    Traverse all fields of an AST node, threading *state* through them.

    Each field value is passed to ``self._walk_field`` together with the
    state produced by the previous field. ``block_context`` is True when the
    field name is in ``_BLOCK_FIELDS`` and the value is exactly a ``list``.

    When ``self._transform`` is set and at least one field came back as a
    different object, a new node of the same type is built from the new
    field values. Otherwise the original node is returned.
    NOTE(review): the rebuilt node receives only fields, not location
    attributes (lineno/col_offset); confirm callers fix locations.

    Returns:
        ``(new_state, node_or_rebuilt_node)``.
    """
    rebuilt_fields = {}
    any_changed = False
    current_state = state
    for field_name, field_value in ast.iter_fields(node):
        is_block = field_name in _BLOCK_FIELDS and type(field_value) is list
        current_state, replacement = self._walk_field(
            current_state, field_value, ctx, block_context=is_block)
        if self._transform:
            rebuilt_fields[field_name] = replacement
            any_changed = any_changed or replacement is not field_value
    if self._transform and any_changed:
        return current_state, type(node)(**rebuilt_fields)
    return current_state, node
def ast2tree(node, include_attrs=True):
    """Convert an AST into nested plain tuples, dicts and lists.

    Each node becomes ``(class_name, {field: converted_value})``. When
    *include_attrs* is true, it becomes
    ``(class_name, fields_dict, attrs_dict)``, where attrs_dict holds the
    node's ``_attributes`` that are actually set. Lists are converted
    element-wise, strings are replaced by their ``repr``, and any other value
    is returned unchanged.

    Raises:
        TypeError: if *node* is not an ``ast.AST`` instance.
    """
    if not isinstance(node, ast.AST):
        raise TypeError('expected AST, got %r' % node.__class__.__name__)

    def convert(value):
        if isinstance(value, list):
            return [convert(item) for item in value]
        if isinstance(value, str):
            return repr(value)
        if not isinstance(value, ast.AST):
            return value
        kind = value.__class__.__name__
        field_map = {name: convert(child) for name, child in ast.iter_fields(value)}
        if not include_attrs:
            return (kind, field_map)
        attr_map = {name: convert(getattr(value, name))
                    for name in value._attributes
                    if hasattr(value, name)}
        return (kind, field_map, attr_map)

    return convert(node)
def findListId(a, id):
    """Find the list, anywhere inside *a*, whose first element has ``global_id == id``.

    *a* may be a list, an AST node, or any other value. Lists are checked
    first, then searched item by item. AST nodes are searched field by field.
    The search is depth-first, and the first matching list is returned. Any
    other value, or a failed search, gives None.
    """
    # We want to go one level up to get the list this belongs to
    if type(a) is list:
        if a and hasattr(a[0], "global_id") and a[0].global_id == id:
            return a
        children = a
    elif isinstance(a, ast.AST):
        children = (value for _, value in ast.iter_fields(a))
    else:
        return None
    for child in children:
        found = findListId(child, id)
        if found is not None:
            return found
    return None
def dump(node, annotate_fields=True, include_attributes=False):
    """
    Return a formatted dump of the tree in *node*. This is mainly useful for
    debugging purposes. The returned string will show the names and the values
    for fields. This makes the code impossible to evaluate, so if evaluation is
    wanted *annotate_fields* must be set to False. Attributes such as line
    numbers and column offsets are not dumped by default. If this is wanted,
    *include_attributes* can be set to True. Attributes that are not set on a
    node (for example, on nodes built by hand) are skipped.

    Raises:
        TypeError: if *node* is not an AST instance.
    """
    def _format(node):
        if isinstance(node, AST):
            fields = [(a, _format(b)) for a, b in iter_fields(node)]
            if annotate_fields:
                parts = ['%s=%s' % field for field in fields]
            else:
                parts = [b for a, b in fields]
            if include_attributes:
                # Attributes are always written as name=value. Missing ones
                # are skipped instead of raising AttributeError.
                parts.extend('%s=%s' % (a, _format(getattr(node, a)))
                             for a in node._attributes
                             if hasattr(node, a))
            return '%s(%s)' % (node.__class__.__name__, ', '.join(parts))
        elif isinstance(node, list):
            return '[%s]' % ', '.join(_format(x) for x in node)
        return repr(node)
    if not isinstance(node, AST):
        raise TypeError('expected AST, got %r' % node.__class__.__name__)
    return _format(node)
def test_comprehensions(self):
# See https://bitbucket.org/plas/thonny/issues/8/range-marker-doesnt-work-correctly-with
for source in (
"[(key, val) for key, val in ast.iter_fields(node)]",
"((key, val) for key, val in ast.iter_fields(node))",
"{(key, val) for key, val in ast.iter_fields(node)}",
"{key: val for key, val in ast.iter_fields(node)}",
"[[c for c in key] for key, val in ast.iter_fields(node)]"):
m = self.create_mark_checker(source)
m.verify_all_nodes(self)
def test_iter_fields(self):
node = ast.parse('foo()', mode='eval')
d = dict(ast.iter_fields(node.body))
self.assertEqual(d.pop('func').id, 'foo')
self.assertEqual(d, {'keywords': [], 'kwargs': None,
'args': [], 'starargs': None})
def str_ast_node(node):
    """Render *node* as ``ClassName(field=value, ...)``, omitting operands.

    Fields named ``left`` or ``right`` are left out, and the remaining fields
    are rendered recursively. Anything that is not an AST node (including
    lists) is rendered with ``repr``.
    """
    if not isinstance(node, ast.AST):
        return repr(node)
    rendered = ', '.join(
        '%s=%s' % (name, str_ast_node(val))
        for name, val in ast.iter_fields(node)
        if name not in ('left', 'right'))
    return '%s(%s)' % (node.__class__.__name__, rendered)
def test_iter_fields(self):
node = ast.parse('foo()', mode='eval')
d = dict(ast.iter_fields(node.body))
self.assertEqual(d.pop('func').id, 'foo')
self.assertEqual(d, {'keywords': [], 'kwargs': None,
'args': [], 'starargs': None})
def test_iter_fields(self):
node = ast.parse('foo()', mode='eval')
d = dict(ast.iter_fields(node.body))
self.assertEqual(d.pop('func').id, 'foo')
self.assertEqual(d, {'keywords': [], 'kwargs': None,
'args': [], 'starargs': None})
def ast_pretty_dump(node, annotate_fields=True, include_attributes=False, indent=' '):
    """
    Originally copied from ast_demo.dump() source code
    https://github.com/st3fan/pythoncodeanalysis/blob/master/utils/astpp.py
    Return a formatted dump of the tree in *node*. This is mainly useful for
    debugging purposes. The returned string will show the names and the values
    for fields. This makes the code impossible to evaluate, so if evaluation is
    wanted *annotate_fields* must be set to False. Attributes such as line
    numbers and column offsets are not dumped by default. If this is wanted,
    *include_attributes* can be set to True. Attributes that are not set on a
    node (for example, on nodes built by hand) are skipped.

    Non-empty list fields are written one element per line, prefixed with
    *indent* repeated according to nesting depth. Empty lists render as ``[]``.

    Raises:
        TypeError: if *node* is not an ``ast.AST`` instance.
    """
    def _format(node, level=0):
        if isinstance(node, ast.AST):
            fields = [(a, _format(b, level)) for a, b in ast.iter_fields(node)]
            if include_attributes and node._attributes:
                # Skip unset attributes instead of raising AttributeError.
                fields.extend([(a, _format(getattr(node, a), level))
                               for a in node._attributes
                               if hasattr(node, a)])
            return ''.join([
                node.__class__.__name__,
                '(',
                ', '.join(('%s=%s' % field for field in fields)
                          if annotate_fields else
                          (b for a, b in fields)),
                ')'])
        elif isinstance(node, list):
            lines = ['[']
            lines.extend((indent * (level + 2) + _format(x, level + 2) + ','
                          for x in node))
            if len(lines) > 1:
                lines.append(indent * (level + 1) + ']')
            else:
                lines[-1] += ']'
            return '\n'.join(lines)
        return repr(node)
    if not isinstance(node, ast.AST):
        raise TypeError('expected AST, got %r' % node.__class__.__name__)
    return _format(node)
def preprocess_nodes(self, node):
    """Run preprocessors on nodes for the visitor.

    Every child AST node of *node* is first annotated:

    * ``sibling``: the next element of the same list field, or None for the
      last element and for single-node fields;
    * ``parent``: *node*;
    * ``storage``: a fresh empty dict.

    The child is then passed to ``self.pre_visit(child, preprocess=True)``.
    If that returns a truthy value, the child is recursed into and
    ``self.post_visit`` is called on it; otherwise it is left there.
    """
    def _prepare(child, next_sibling):
        # Annotate one child and report whether to descend into it.
        child.sibling = next_sibling
        child.parent = node
        child.storage = {}
        return self.pre_visit(child, preprocess=True)

    for _, field_value in ast.iter_fields(node):
        if isinstance(field_value, ast.AST):
            if _prepare(field_value, None):
                self.preprocess_nodes(field_value)
                self.post_visit(field_value)
        elif isinstance(field_value, list):
            last_index = len(field_value) - 1
            for position, child in enumerate(field_value):
                if not isinstance(child, ast.AST):
                    continue
                following = field_value[position + 1] if position < last_index else None
                if _prepare(child, following):
                    self.preprocess_nodes(child)
                    self.post_visit(child)
def generic_visit(self, node, container=None):
    """Yield everything ``self.visit`` yields for each direct child of *node*.

    Children are the AST values of single-node fields and the AST elements
    of list fields, in field order. *container* is passed through unchanged.
    """
    for child in ast.iter_child_nodes(node):
        yield from self.visit(child, container)
def generic_visit(self, node, container=None):
    """Yield everything ``self.visit`` yields for each direct child of *node*.

    Children are the AST values of single-node fields and the AST elements
    of list fields, in field order. *container* is passed through unchanged.
    """
    for child in ast.iter_child_nodes(node):
        yield from self.visit(child, container)
def generic_visit(self, node, container=None):
    """Yield everything ``self.visit`` yields for each direct child of *node*.

    Children are the AST values of single-node fields and the AST elements
    of list fields, in field order. *container* is passed through unchanged.
    """
    for child in ast.iter_child_nodes(node):
        yield from self.visit(child, container)
def generic_visit(self, node):
    """Called if no explicit visitor function exists for a node.

    Each list-valued field is first handed whole to
    ``self._handle_ast_list``, and then each AST element in it is visited.
    Single-node fields are visited directly. Fields are handled in order.
    """
    for _, child in ast.iter_fields(node):
        if isinstance(child, ast.AST):
            self.visit(child)
            continue
        if not isinstance(child, list):
            continue
        self._handle_ast_list(child)
        for element in child:
            if isinstance(element, ast.AST):
                self.visit(element)
def test_iter_fields(self):
node = ast.parse('foo()', mode='eval')
d = dict(ast.iter_fields(node.body))
self.assertEqual(d.pop('func').id, 'foo')
self.assertEqual(d, {'keywords': [], 'kwargs': None,
'args': [], 'starargs': None})
def _iter_all_ast(node):
yield node
for field, value in ast.iter_fields(node):
if isinstance(value, list):
for item in value:
if isinstance(item, ast.AST):
for child in _iter_all_ast(item):
yield child
elif isinstance(value, ast.AST):
for child in _iter_all_ast(value):
yield child
def copy_node(node):
    """Return a shallow copy of *node*.

    The copy is a new instance of the same class. It shares every field value
    with the original (child nodes and lists are not duplicated), and it
    carries over whichever location ``_attributes`` are set on the original.
    """
    clone = type(node)()
    for name, val in ast.iter_fields(node):
        setattr(clone, name, val)
    for name in node._attributes:
        if hasattr(node, name):
            setattr(clone, name, getattr(node, name))
    return clone
def generic_visit(self, node):
    """Visit children in field order, stopping at the first falsy ``visit`` result.

    Returns True only if every child visit returned a truthy value. Children
    after the first failing one are not visited.
    """
    return all(self.visit(child) for child in ast.iter_child_nodes(node))