def flatten_array_declarations(root):
    """Expand array declarations into one scalar declaration per element.

    An assignment whose right-hand side is a call that is immediately
    subscripted, e.g. ``arr = T(args)[2, 3]``, is replaced by a list of
    assignments, one per index tuple in ``itertools.product`` over the
    subscript's shape::

        <flattened_array_name('arr', (0, 0))> = T(args)
        <flattened_array_name('arr', (0, 1))> = T(args)
        ...

    Each generated call receives deep copies of the original call's
    positional and keyword arguments, so no AST node is shared between the
    generated statements.

    Function bodies are not visited, so declarations inside ``def`` blocks
    are left as they are.

    Args:
        root: AST node to transform (presumably an ``ast.Module``).

    Returns:
        The result of ``NodeTransformer.visit(root)`` after the rewrite.
    """

    class Transformer(ast.NodeTransformer):
        def visit_FunctionDef(self, node):
            # Returning the node without calling generic_visit stops the
            # transformer from descending into the function body.
            return node

        def visit_Assign(self, node):
            value = node.value
            if not (isinstance(value, ast.Subscript)
                    and isinstance(value.value, ast.Call)):
                return node

            subscr = value
            call = subscr.value
            if len(node.targets) > 1:
                error.error('Cannot use multiple assignment in array declaration.', node)
            # NOTE(review): assumes the target is a plain ast.Name and the
            # callee is a bare name (``T(...)``, not ``mod.T(...)``);
            # anything else raises AttributeError here.
            variable_name = node.targets[0].id
            value_type = call.func.id

            # The subscript holds the array shape, e.g. (2, 3); the helper
            # presumably returns a tuple of ints -- confirm against its source.
            shape = slice_node_to_tuple_of_numbers(subscr.slice)

            new_assigns = []
            for indices in itertools.product(*[range(n) for n in shape]):
                index_name = flattened_array_name(variable_name, indices)
                target = ast.copy_location(
                    ast.Name(id=index_name, ctx=ast.Store()), node)
                func = ast.copy_location(
                    ast.Name(id=value_type, ctx=ast.Load()), node)
                # Deep-copy so every generated statement owns its argument nodes.
                args = [copy.deepcopy(arg) for arg in call.args]
                keywords = [copy.deepcopy(kw) for kw in call.keywords]
                # Keyword construction: Python 3.5+ removed the
                # starargs/kwargs fields, so the old 5-positional form
                # ast.Call(func, args, [], None, None) raises TypeError.
                new_call = ast.copy_location(
                    ast.Call(func=func, args=args, keywords=keywords), node)
                new_assign = ast.copy_location(
                    ast.Assign(targets=[target], value=new_call), node)
                new_assigns.append(new_assign)
            return new_assigns

    return Transformer().visit(root)
评论列表
文章目录