def __init__(self, stmt, context):
    """Translate a single statement node and store the result on the instance.

    The statement's class is looked up in ``self.stmt_table`` and the
    matching bound handler is called with no arguments; its return value
    becomes ``self.lll_node``. A bare name ``throw`` is turned into an
    ``['assert', 0]`` LLL node. Anything else raises ``StructureException``.

    :param stmt: the AST node to translate.
    :param context: stored as ``self.context`` for the handlers to use.
    """
    self.stmt = stmt
    self.context = context
    # Statement class -> bound handler method.
    self.stmt_table = {
        ast.Expr: self.expr,
        ast.Pass: self.parse_pass,
        ast.AnnAssign: self.ann_assign,
        ast.Assign: self.assign,
        ast.If: self.parse_if,
        ast.Call: self.call,
        ast.Assert: self.parse_assert,
        ast.For: self.parse_for,
        ast.AugAssign: self.aug_assign,
        ast.Break: self.parse_break,
        ast.Return: self.parse_return,
    }

    handler = self.stmt_table.get(self.stmt.__class__)
    if handler is not None:
        self.lll_node = handler()
        return

    if isinstance(stmt, ast.Name) and stmt.id == "throw":
        self.lll_node = LLLnode.from_list(['assert', 0], typ=None, pos=getpos(stmt))
        return

    raise StructureException("Unsupported statement type", stmt)
python类AugAssign()的实例源码
def getAllGlobalNames(a):
    """Find all names that are bound at the global level of the AST.

    Only the direct children of an ``ast.Module`` are inspected:

    * function and class definitions contribute their name;
    * assignments and augmented assignments contribute plain-name targets
      and the plain names found one level inside tuple/list targets;
    * imports contribute the alias if one is given, else the imported name.

    :param a: the node to inspect.
    :return: list of names in source order (duplicates are kept); an empty
        list when ``a`` is not an ``ast.Module``.
    """
    if type(a) != ast.Module:
        return []
    names = []
    for obj in a.body:
        if type(obj) in [ast.FunctionDef, ast.ClassDef]:
            names.append(obj.name)
        elif type(obj) in [ast.Assign, ast.AugAssign]:
            # AugAssign has a single `target` attribute and no `targets`,
            # so normalise both forms to a list before iterating.
            targets = obj.targets if type(obj) == ast.Assign else [obj.target]
            for target in targets:
                if type(target) == ast.Name:
                    names.append(target.id)
                elif type(target) in [ast.Tuple, ast.List]:
                    for elt in target.elts:
                        if type(elt) == ast.Name:
                            names.append(elt.id)
        elif type(obj) in [ast.Import, ast.ImportFrom]:
            for module in obj.names:
                names.append(module.asname if module.asname is not None else module.name)
    return names
def visit_Assert(self, node):
"""Replace assertions with augmented assignments."""
self.max_score += 1
return ast.AugAssign(
op=ast.Add(),
target=ast.Name(
id=self.score_var_name,
ctx=ast.Store()
),
value=ast.Call(
args=[node.test],
func=ast.Name(
id='bool',
ctx=ast.Load()
),
keywords=[],
kwargs=None,
starargs=None
)
)
def get_coverable_nodes(cls):
    """Return the set of ``ast`` node classes treated as coverable.

    This variant lists ``ast.TryExcept`` and ``ast.TryFinally`` alongside
    ``ast.Nonlocal``; it raises ``AttributeError`` on an interpreter whose
    ``ast`` module lacks any of the referenced classes.

    :param cls: unused here (presumably the owning class -- confirm).
    :return: a new ``set`` of node classes.
    """
    simple_statements = (
        ast.Assert,
        ast.Assign,
        ast.AugAssign,
        ast.Break,
        ast.Continue,
        ast.Delete,
        ast.Expr,
        ast.Global,
        ast.Import,
        ast.ImportFrom,
        ast.Nonlocal,
        ast.Pass,
        ast.Raise,
        ast.Return,
    )
    compound_statements = (
        ast.FunctionDef,
        ast.ClassDef,
        ast.TryExcept,
        ast.TryFinally,
        ast.ExceptHandler,
        ast.If,
        ast.For,
        ast.While,
    )
    return set(simple_statements + compound_statements)
def get_coverable_nodes(cls):
    """Return the set of ``ast`` node classes treated as coverable.

    This variant uses the unified ``ast.Try`` node.

    :param cls: unused here (presumably the owning class -- confirm).
    :return: a new ``set`` of node classes.
    """
    simple_statements = (
        ast.Assert,
        ast.Assign,
        ast.AugAssign,
        ast.Break,
        ast.Continue,
        ast.Delete,
        ast.Expr,
        ast.Global,
        ast.Import,
        ast.ImportFrom,
        ast.Nonlocal,
        ast.Pass,
        ast.Raise,
        ast.Return,
    )
    compound_statements = (
        ast.ClassDef,
        ast.FunctionDef,
        ast.Try,
        ast.ExceptHandler,
        ast.If,
        ast.For,
        ast.While,
    )
    return set(simple_statements + compound_statements)
def allVariableNamesUsed(a):
    """Collect every variable name that is *used* somewhere in the AST.

    :param a: any object; non-AST values yield an empty list.
    :return: list of identifier strings in traversal order, duplicates kept.
    """
    if not isinstance(a, ast.AST):
        return []

    node_type = type(a)
    if node_type == ast.Name:
        return [a.id]

    if node_type == ast.Assign:
        # Pure names on the left-hand side are being assigned to, not used,
        # so they are skipped -- both as direct targets and as direct
        # elements of a tuple/list target.
        found = allVariableNamesUsed(a.value)
        for target in a.targets:
            target_type = type(target)
            if target_type == ast.Name:
                continue
            if target_type in [ast.Tuple, ast.List]:
                for elt in target.elts:
                    if type(elt) != ast.Name:
                        found.extend(allVariableNamesUsed(elt))
            else:
                found.extend(allVariableNamesUsed(target))
        return found

    if node_type == ast.AugAssign:
        # The target of an augmented assignment counts as used as well.
        found = allVariableNamesUsed(a.value)
        found.extend(allVariableNamesUsed(a.target))
        return found

    found = []
    for child in ast.iter_child_nodes(a):
        found.extend(allVariableNamesUsed(child))
    return found
def staticVars(l, vars):
    """Determine whether the given lines leave the given variables unchanged.

    :param l: list of statement nodes to scan.
    :param vars: nodes exposing an ``id`` attribute (and optionally a
        ``type`` attribute) for the variables of interest.
    :return: ``False`` as soon as a statement may change one of the
        variables, ``True`` if none does.
    """
    # First, if one of the variables can be modified, there might be a problem:
    # anything not known to be an int/float/str/bool is treated as mutable.
    mutableVars = [
        var for var in vars
        if not (hasattr(var, "type") and (var.type in [int, float, str, bool]))
    ]

    for stmt in l:
        stmt_type = type(stmt)
        if stmt_type == ast.Assign:
            # Check every target, not only the first, so chained
            # assignments such as `a = b = 0` are detected for `b` too.
            for target in stmt.targets:
                assigned = allVariableNamesUsed(target)
                for var in vars:
                    if var.id in assigned:
                        return False
        elif stmt_type == ast.AugAssign:
            for var in vars:
                if var.id in allVariableNamesUsed(stmt.target):
                    return False
        elif stmt_type in [ast.If, ast.While]:
            if not (staticVars(stmt.body, vars) and staticVars(stmt.orelse, vars)):
                return False
        elif stmt_type == ast.For:
            # The loop variable itself is (re)assigned on every iteration.
            for var in vars:
                if var.id in allVariableNamesUsed(stmt.target):
                    return False
            if not (staticVars(stmt.body, vars) and staticVars(stmt.orelse, vars)):
                return False
        elif stmt_type in [ast.FunctionDef, ast.ClassDef, ast.Try, ast.With]:
            log("transformations\tstaticVars\tMissing type: " + str(stmt_type), "bug")

        # If a mutable variable is used, we can't trust it
        for var in mutableVars:
            if var.id in allVariableNamesUsed(stmt):
                return False
    return True
def isStatement(a):
    """Determine whether the given node is a statement (vs an expression).

    The top-level container nodes (``Module``, ``Interactive``,
    ``Expression``, ``Suite``) are counted as statements as well.

    :param a: the node (or any object) to classify.
    :return: ``True`` if ``type(a)`` is exactly one of the listed classes.
    """
    statement_type_names = (
        "Module", "Interactive", "Expression", "Suite",
        "FunctionDef", "ClassDef", "Return", "Delete",
        "Assign", "AugAssign", "For", "While",
        "If", "With", "Raise", "Try",
        "Assert", "Import", "ImportFrom", "Global",
        "Expr", "Pass", "Break", "Continue",
    )
    # Resolve by name so that a class missing from this interpreter's `ast`
    # module (e.g. the deprecated `Suite`) is skipped instead of raising
    # AttributeError on every call.
    statement_types = [
        getattr(ast, name) for name in statement_type_names if hasattr(ast, name)
    ]
    return type(a) in statement_types
def getAllAssignedVarIds(a):
    """Gather the ids assigned anywhere inside the AST ``a``.

    Every ``Assign``, ``AugAssign`` and ``For`` node found by ``ast.walk``
    has its target(s) passed, as a list, to ``gatherAssignedVarIds``; the
    results are concatenated in walk order.

    :param a: any object; non-AST values yield an empty list.
    :return: list of whatever ``gatherAssignedVarIds`` yields.
    """
    if not isinstance(a, ast.AST):
        return []
    collected = []
    for node in ast.walk(a):
        node_type = type(node)
        if node_type == ast.Assign:
            collected.extend(gatherAssignedVarIds(node.targets))
        elif node_type in (ast.AugAssign, ast.For):
            # Both carry a single `target`; wrap it to match the list form.
            collected.extend(gatherAssignedVarIds([node.target]))
    return collected
def test_augassign(self):
aug = ast.AugAssign(ast.Name("x", ast.Load()), ast.Add(),
ast.Name("y", ast.Load()))
self.stmt(aug, "must have Store context")
aug = ast.AugAssign(ast.Name("x", ast.Store()), ast.Add(),
ast.Name("y", ast.Store()))
self.stmt(aug, "must have Load context")
def find_const_value(defnode, arg_str, seen_names):
"""
Try to reduce arg_str to a constant value.

arg_str usually represents a dimension size of an array,
e.g.: a // 4 + 8
If it cannot be evaluated directly, try to replace the variables it
names with the values assigned to them under defnode, then retry
recursively on the rewritten expression.

seen_names is appended to in place with each name that gets substituted.
TransformError is raised when a name is already in seen_names or when the
substitution cannot be done.
"""
# NOTE(review): indentation was lost in this copy, so the exact nesting
# below (in particular which statement the final `else` pairs with) must
# be confirmed against the upstream file.
try:
# NOTE(review): eval() executes arg_str as Python code -- confirm that
# arg_str can never come from untrusted input.
value = eval(arg_str)
return value
# NOTE(review): a bare except catches every exception, including
# KeyboardInterrupt and SystemExit.
except:
# Parse the expression and collect the distinct names it mentions.
dimension_node = py_ast.get_ast(arg_str).body[0].value
namenodes = py_ast.find_all(dimension_node, ast.Name)
names = []
for namenode in namenodes:
if namenode.id not in names:
names.append(namenode.id)
# Candidate definitions: plain and augmented assignments found under defnode.
assignnodes = py_ast.find_all(defnode, ast.Assign)
aug_assignnodes = py_ast.find_all(defnode, ast.AugAssign)
for name in names:
# A name that was already substituted once is not substituted again.
if name in seen_names:
raise TransformError('could not replace variable to const')
# Single-target assignments `name = ...` and augmented assignments to name.
potential_assignnodes = [assignnode for assignnode in assignnodes if len(assignnode.targets) == 1 and isinstance(assignnode.targets[0], ast.Name) and assignnode.targets[0].id == name]
potential_augassigns = [assignnode for assignnode in aug_assignnodes if isinstance(assignnode.target, ast.Name) and assignnode.target.id == name]
# Substitute only when there is exactly one plain assignment and no
# augmented assignment to the name.
if len(potential_assignnodes) == 1 and len(potential_augassigns) == 0:
seen_names.append(name)
for namenode in namenodes:
if namenode.id == name:
py_ast.replace_node(dimension_node, namenode, potential_assignnodes[0].value)
# Retry on the rewritten expression text.
return find_const_value(defnode, py_ast.dump_ast(dimension_node), seen_names)
else:
raise TransformError('could not replace variable to const')
def convert(self):
# Parse each entry of self.input_source_code and dispatch every top-level
# statement to the handler registered for its class.
# NOTE(review): indentation was lost in this copy; whether the `source`
# loop is nested inside the `func_name` loop, and where the final
# __write_file() call sits, must be confirmed against the upstream file.
# Dispatch table: top-level statement class -> handler taking that node.
code_type = {
ast.Assign: self.assign_code,
ast.Expr: self.expr_code,
ast.AugAssign: self.aug_assign_code,
}
for func_name in self.source_code.func_key:
self.processing_func = self.get_func_by_name(func_name)
for source in self.input_source_code:
parse = ast.parse(source)
for body in parse.body:
# A statement class missing from code_type raises KeyError here.
code_type[body.__class__](body)
# NOTE(review): double-underscore names are name-mangled inside a class
# body -- presumably this is a method of a class; confirm.
self.__write_file()
def __init__(self, name, source, scope):
    """Record the string names listed in an ``__all__`` binding.

    When ``scope`` already holds ``__all__`` and ``source`` is an augmented
    assignment, start from a copy of the names of that existing binding;
    otherwise start empty. String elements of a list/tuple right-hand side
    are then appended.

    :param name: passed through to the parent initializer.
    :param source: the assignment node; its ``value`` is inspected.
    :param scope: mapping consulted for an existing ``'__all__'`` entry.
    """
    if '__all__' in scope and isinstance(source, ast.AugAssign):
        self.names = list(scope['__all__'].names)
    else:
        self.names = []
    if isinstance(source.value, (ast.List, ast.Tuple)):
        # `ast.Str` is a deprecated alias of `ast.Constant` (warns on
        # access in newer interpreters and has been removed from the
        # latest ones), so match string constants without touching it.
        constant_cls = getattr(ast, 'Constant', None)
        for node in source.value.elts:
            if constant_cls is not None and isinstance(node, constant_cls):
                if isinstance(node.value, str):
                    self.names.append(node.value)
            elif type(node).__name__ == 'Str':
                # Interpreters whose parser still emits dedicated Str nodes.
                self.names.append(node.s)
    super(ExportBinding, self).__init__(name, source)
def test_augassign(self):
aug = ast.AugAssign(ast.Name("x", ast.Load()), ast.Add(),
ast.Name("y", ast.Load()))
self.stmt(aug, "must have Store context")
aug = ast.AugAssign(ast.Name("x", ast.Store()), ast.Add(),
ast.Name("y", ast.Store()))
self.stmt(aug, "must have Load context")
def _translate_body(self, body, allow_loose_in_edges=False, allow_loose_out_edges=False):
    """Translate a list of statements into a single CFG.

    Assignments and expression statements are added to the current basic
    block. ``if``/``while``/``break``/``continue`` close the current block
    and append the CFG produced by ``self.visit``. ``pass`` appends a dummy
    CFG unless a block is currently incomplete.

    :param body: iterable of statement nodes.
    :param allow_loose_in_edges: when false, a dummy CFG is prepended if
        the result has loose in-edges.
    :param allow_loose_out_edges: when false, a dummy CFG is appended if
        the result has loose out-edges.
    :return: ``cfg_factory.cfg`` after all statements were processed.
    :raises NotImplementedError: for any other statement type.
    """
    cfg_factory = CfgFactory(self._id_gen)
    block_statements = (ast.Assign, ast.AugAssign, ast.Expr)
    sub_cfg_statements = (ast.If, ast.While, ast.Break, ast.Continue)

    for child in body:
        if isinstance(child, block_statements):
            cfg_factory.add_stmts(self.visit(child))
            continue
        if isinstance(child, sub_cfg_statements):
            # Close the running block, then splice in the statement's own CFG.
            cfg_factory.complete_basic_block()
            cfg_factory.append_cfg(self.visit(child))
            continue
        if isinstance(child, ast.Pass):
            if not cfg_factory.incomplete_block():
                cfg_factory.append_cfg(_dummy_cfg(self._id_gen))
            continue
        raise NotImplementedError(f"The statement {str(type(child))} is not yet translatable to CFG!")

    cfg_factory.complete_basic_block()

    needs_entry = not allow_loose_in_edges and cfg_factory.cfg and cfg_factory.cfg.loose_in_edges
    if needs_entry:
        cfg_factory.prepend_cfg(_dummy_cfg(self._id_gen))
    needs_exit = not allow_loose_out_edges and cfg_factory.cfg and cfg_factory.cfg.loose_out_edges
    if needs_exit:
        cfg_factory.append_cfg(_dummy_cfg(self._id_gen))

    return cfg_factory.cfg
def infer(node, context, solver):
    """Dispatch type inference for a statement node.

    Each supported statement class is forwarded to its dedicated helper.
    An expression statement has its value inferred for effect only, and,
    like any unhandled node, yields ``solver.z3_types.none``.

    :param node: the statement node.
    :param context: passed through to the helpers.
    :param solver: passed through to the helpers; also supplies
        ``z3_types.none``.
    :return: the helper's result, or ``solver.z3_types.none``.
    """
    if isinstance(node, ast.Assign):
        return _infer_assign(node, context, solver)
    if isinstance(node, ast.AugAssign):
        return _infer_augmented_assign(node, context, solver)
    if isinstance(node, ast.Return):
        if not node.value:
            return solver.z3_types.none
        return expr.infer(node.value, context, solver)
    if isinstance(node, ast.Delete):
        return _infer_delete(node, context, solver)
    if isinstance(node, (ast.If, ast.While)):
        return _infer_control_flow(node, context, solver)
    if isinstance(node, ast.For):
        return _infer_for(node, context, solver)
    # AsyncFor is introduced in Python 3.5
    if sys.version_info[0] >= 3 and sys.version_info[1] >= 5 and isinstance(node, ast.AsyncFor):
        return _infer_for(node, context, solver)
    if isinstance(node, ast.With):
        return _infer_with(node, context, solver)
    # AsyncWith is introduced in Python 3.5
    if sys.version_info[0] >= 3 and sys.version_info[1] >= 5 and isinstance(node, ast.AsyncWith):
        return _infer_with(node, context, solver)
    if isinstance(node, ast.Try):
        return _infer_try(node, context, solver)
    if isinstance(node, ast.FunctionDef):
        return _infer_func_def(node, context, solver)
    if isinstance(node, ast.ClassDef):
        return _infer_class_def(node, context, solver)
    if isinstance(node, ast.Import):
        return _infer_import(node, context, solver)
    if isinstance(node, ast.ImportFrom):
        return _infer_import_from(node, context, solver)
    if isinstance(node, ast.Expr):
        # The result is discarded; the statement itself has the none type.
        expr.infer(node.value, context, solver)
    return solver.z3_types.none
def __init__(self, name, source, scope):
    """Record the string names listed in an ``__all__`` binding.

    When ``scope`` already holds ``__all__`` and ``source`` is an augmented
    assignment, start from a copy of the names of that existing binding;
    otherwise start empty. String elements of a list/tuple right-hand side
    are then appended.

    :param name: passed through to the parent initializer.
    :param source: the assignment node; its ``value`` is inspected.
    :param scope: mapping consulted for an existing ``'__all__'`` entry.
    """
    if '__all__' in scope and isinstance(source, ast.AugAssign):
        self.names = list(scope['__all__'].names)
    else:
        self.names = []
    if isinstance(source.value, (ast.List, ast.Tuple)):
        # `ast.Str` is a deprecated alias of `ast.Constant` (warns on
        # access in newer interpreters and has been removed from the
        # latest ones), so match string constants without touching it.
        constant_cls = getattr(ast, 'Constant', None)
        for node in source.value.elts:
            if constant_cls is not None and isinstance(node, constant_cls):
                if isinstance(node.value, str):
                    self.names.append(node.value)
            elif type(node).__name__ == 'Str':
                # Interpreters whose parser still emits dedicated Str nodes.
                self.names.append(node.s)
    super(ExportBinding, self).__init__(name, source)
def should_mutate(self, node):
    """Return True unless the node's parent is an augmented assignment.

    :param node: a node exposing a ``parent`` attribute.
    """
    inside_aug_assign = isinstance(node.parent, ast.AugAssign)
    return not inside_aug_assign
def should_mutate(self, node):
    """Return True only when the node's parent is an augmented assignment.

    :param node: a node exposing a ``parent`` attribute.
    """
    parent = node.parent
    return isinstance(parent, ast.AugAssign)
def test_augassign(self):
aug = ast.AugAssign(ast.Name("x", ast.Load()), ast.Add(),
ast.Name("y", ast.Load()))
self.stmt(aug, "must have Store context")
aug = ast.AugAssign(ast.Name("x", ast.Store()), ast.Add(),
ast.Name("y", ast.Store()))
self.stmt(aug, "must have Load context")
def __init__(self, name, source, scope):
    """Record the string names listed in an ``__all__`` binding.

    When ``scope`` already holds ``__all__`` and ``source`` is an augmented
    assignment, start from a copy of the names of that existing binding;
    otherwise start empty. String elements of a list/tuple right-hand side
    are then appended.

    :param name: passed through to the parent initializer.
    :param source: the assignment node; its ``value`` is inspected.
    :param scope: mapping consulted for an existing ``'__all__'`` entry.
    """
    if '__all__' in scope and isinstance(source, ast.AugAssign):
        self.names = list(scope['__all__'].names)
    else:
        self.names = []
    if isinstance(source.value, (ast.List, ast.Tuple)):
        # `ast.Str` is a deprecated alias of `ast.Constant` (warns on
        # access in newer interpreters and has been removed from the
        # latest ones), so match string constants without touching it.
        constant_cls = getattr(ast, 'Constant', None)
        for node in source.value.elts:
            if constant_cls is not None and isinstance(node, constant_cls):
                if isinstance(node.value, str):
                    self.names.append(node.value)
            elif type(node).__name__ == 'Str':
                # Interpreters whose parser still emits dedicated Str nodes.
                self.names.append(node.s)
    super(ExportBinding, self).__init__(name, source)
def augAssignSpecialFunction(cv, orig):
# Rewrite a change vector that edits part of an (aug)assignment so that it
# targets the enclosing Assign/AugAssign statement where that is needed.
# Returns cv itself when no special handling applies, otherwise a new
# ChangeVector. `orig` is not referenced in this function.
# NOTE(review): indentation was lost in this copy; the nesting of the
# branches below must be confirmed against the upstream file.
# Only non-delete changes to a non-statement subtree that carries an
# "augAssignVal" or "augAssignBinOp" tag are considered.
if (not isinstance(cv, DeleteVector)) and (not isStatement(cv.oldSubtree)) and \
(childHasTag(cv.oldSubtree, "augAssignVal") or childHasTag(cv.oldSubtree, "augAssignBinOp")):
# First, create the oldTree and newTree in full
cvCopy = cv.deepcopy()
cvCopy.start = deepcopy(cv.start)
newTree = cvCopy.applyChange(caller="augAssignSpecialFunction")
# This should be in an augassign, move up in the tree until we reach it.
spot = cv.oldSubtree
cvCopy = cv
i = 0
# Each step drops one more leading element of cv.path and re-traverses
# from cv.start, until an Assign/AugAssign is found or the path runs out.
while type(spot) not in [ast.Assign, ast.AugAssign] and len(cvCopy.path) > i:
i += 1
cvCopy = cv.deepcopy()
cvCopy.path = cv.path[i:]
spot = deepcopy(cvCopy.traverseTree(cv.start))
# Double check to make sure this is actually still an augassign
if type(spot) in [ast.Assign, ast.AugAssign] and hasattr(spot, "global_id"):
newCv = cv.deepcopy()
newCv.path = cv.path[i+1:]
newCv.oldSubtree = spot
# find the new spot
cvCopy = cv.deepcopy()
cvCopy.path = cv.path[i:]
newSpot = cvCopy.traverseTree(newTree)
if type(newSpot) == type(spot):
# Don't do special things when they aren't needed
if type(newSpot) == ast.Assign:
# compareASTs(...) == 0 is used throughout as "the two trees match".
if compareASTs(newSpot.targets, spot.targets, checkEquality=True) == 0:
# If the two have the same targets and are both binary operations with the target as the left value...
# just change the value
if type(newSpot.value) == type(spot.value) == ast.BinOp:
if compareASTs(spot.targets[0], spot.value.left, checkEquality=True) == 0 and \
compareASTs(newSpot.targets[0], newSpot.value.left, checkEquality=True) == 0:
# we just want to change the values
return ChangeVector([("right", "Binary Operation"), ("value", "Assign")] + newCv.path,
spot.value.right, newSpot.value.right, newCv.start)
elif compareASTs(newSpot.value, spot.value, checkEquality=True) == 0:
return cv
else:
log("Assign", "bug")
elif type(newSpot) == ast.AugAssign:
# Count how many of op / target / value differ between old and new;
# the original change vector is kept when exactly one differs.
diffCount = 0
if compareASTs(newSpot.op, spot.op, checkEquality=True) != 0:
diffCount += 1
if compareASTs(newSpot.target, spot.target, checkEquality=True) != 0:
diffCount += 1
if compareASTs(newSpot.value, spot.value, checkEquality=True) != 0:
diffCount += 1
if diffCount == 1:
return cv
else:
log("AugAssign: " + str(diffCount), "bug")
else:
log("Mismatched types: " + str(type(newSpot)) + "," + str(type(spot)), "bug")
# Fall back to replacing the whole statement.
return ChangeVector(newCv.path, spot, newSpot, start=newCv.start)
return cv
def is_before(node1, node2):
"""
Check whether node1 definitely appears before node2.

Returns True only when the Assign/AugAssign ancestor of node1 and
node2 (or one of node2's ancestors) sit in the same list field of a
common parent, with node1's statement at the lower index. Every other
case, including any exception raised while checking, returns False.
Both nodes are expected to expose a `parent` attribute.
"""
# NOTE(review): indentation was lost in this copy; confirm the nesting of
# the try block against the upstream file.
parents1 = []
current = node1.parent
#first get parent_list of node1
# Walk up from node1 until its enclosing Assign/AugAssign is found;
# ancestors passed on the way are collected in parents1.
while current is not None:
if isinstance(current, ast.Assign) or isinstance(current, ast.AugAssign):
parent1 = current
break
parents1.append(current)
current = current.parent
current = node2
# node2 being one of those intermediate ancestors of node1 counts as "not before".
if current in parents1:
return False
# Walk up from node2 looking for a level that shares a parent with parent1.
while current is not None:
try:
# The string below is an earlier version of this check kept as inert text.
"""if current.parent == parent1.parent:
parent_field = current.parent_field
field_list = getattr(current.parent, parent_field)
list_index1 = field_list.index(parent1)
list_index2 = field_list.index(current)
if list_index2 > list_index1:
return True"""
if current.parent == parent1.parent:
for field, value in ast.iter_fields(parent1.parent):
# Either node stored directly in a field (not in a list): give up.
if value == current or value == parent1:
return False
elif isinstance(value, list) and current in value and parent1 in value:
# Same statement list: compare positions.
list_index1 = value.index(parent1)
list_index2 = value.index(current)
if list_index2 > list_index1:
return True
# NOTE(review): the bare except also covers the case where no
# Assign/AugAssign ancestor was found above and parent1 is unbound.
except:
return False
current = current.parent
return False