def find_const_value(defnode, arg_str, seen_names):
    """
    Reduce an expression string to a constant value.

    ``arg_str`` usually represents a dimension size of an array, e.g.
    ``a // 4 + 8``.  It is first evaluated directly; if that fails, one
    variable in it is replaced by the right-hand side of its assignment
    found in ``defnode`` and the resulting expression is resolved
    recursively.

    A name is substituted only when it is the target of exactly one plain
    assignment (single ``ast.Name`` target) and of no augmented assignment.

    Args:
        defnode: AST node searched for ``ast.Assign`` / ``ast.AugAssign``
            statements (presumably the enclosing function definition --
            confirm against callers).
        arg_str: Source text of the expression to resolve.
        seen_names: List of variable names already substituted during this
            resolution.  It is mutated in place (names are appended) and is
            used to stop when a name shows up again after substitution.

    Returns:
        Whatever ``eval`` yields for the fully substituted expression.

    Raises:
        TransformError: If a name in the expression was already substituted
            once, or if no name in the expression can be substituted.
    """
    try:
        # NOTE(review): eval() executes arbitrary code. arg_str must only
        # come from trusted source text, never from external input.
        return eval(arg_str)
    except Exception:
        # Not directly evaluable (e.g. it mentions an unbound variable):
        # fall back to substituting variables by their assigned expressions.
        # ``Exception`` instead of a bare ``except`` so KeyboardInterrupt and
        # SystemExit still propagate to the caller.
        dimension_node = py_ast.get_ast(arg_str).body[0].value
        namenodes = py_ast.find_all(dimension_node, ast.Name)

        # Distinct variable names, in order of first appearance.
        names = []
        for namenode in namenodes:
            if namenode.id not in names:
                names.append(namenode.id)

        assignnodes = py_ast.find_all(defnode, ast.Assign)
        aug_assignnodes = py_ast.find_all(defnode, ast.AugAssign)

        for name in names:
            # Seeing a name again after it was substituted means its
            # definition leads back to itself; give up.
            if name in seen_names:
                raise TransformError('could not replace variable to const')

            plain_assigns = [
                node for node in assignnodes
                if len(node.targets) == 1
                and isinstance(node.targets[0], ast.Name)
                and node.targets[0].id == name
            ]
            augmented_assigns = [
                node for node in aug_assignnodes
                if isinstance(node.target, ast.Name) and node.target.id == name
            ]

            # Only a single, never-augmented assignment gives an unambiguous
            # replacement; otherwise try the next name.
            if len(plain_assigns) != 1 or augmented_assigns:
                continue

            seen_names.append(name)
            replacement = plain_assigns[0].value
            for namenode in namenodes:
                if namenode.id == name:
                    py_ast.replace_node(dimension_node, namenode, replacement)
            return find_const_value(defnode, py_ast.dump_ast(dimension_node), seen_names)

        # Reached when the expression has no names at all, or none of them
        # could be substituted: fail loudly instead of returning None.
        raise TransformError('could not replace variable to const')
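# (Stray web-page navigation text -- "Comment list" / "Table of contents" -- was here; it is not part of the program.)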