unroller.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:TerpreT 作者: 51alg 项目源码 文件源码
def flatten_array_declarations(root):
    class Transformer(ast.NodeTransformer):
        def visit_FunctionDef(self, node):
            return node

        def visit_Assign(self, node):
            if isinstance(node.value, ast.Subscript) and isinstance(node.value.value, ast.Call):
                subscr = node.value
                call = subscr.value

                if len(node.targets) > 1:
                    error.error('Cannot use multiple assignment in array declaration.', node)

                variable_name = node.targets[0].id
                value_type = call.func.id
                declaration_args = call.args

                # Get the indices being accessed.
                shape = slice_node_to_tuple_of_numbers(subscr.slice)

                new_assigns = []
                for indices in itertools.product(*[range(n) for n in shape]):
                    index_name = flattened_array_name(variable_name, indices)
                    new_index_name_node = ast.copy_location(ast.Name(index_name, ast.Store()), node)
                    new_value_type_node = ast.copy_location(ast.Name(value_type, ast.Load()), node)
                    new_declaration_args = [copy.deepcopy(arg) for arg in declaration_args]
                    new_call_node = ast.copy_location(ast.Call(new_value_type_node, new_declaration_args, [], None, None), node)
                    new_assign = ast.Assign([new_index_name_node], new_call_node)
                    new_assign = ast.copy_location(new_assign, node)
                    new_assigns.append(new_assign)
                return new_assigns
            else:
                return node

    return Transformer().visit(root)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号