def add_input_indices(root, input_vars, index_var):
    """Index every reference to an input variable with ``index_var``.

    Walks the AST ``root`` and handles references to input variables in two ways:

    - A bare reference ``x`` is rewritten to ``x[index_var]``.
    - A subscript whose variable name (as reported by ``get_var_name``) is an
      input is handed to ``extend_subscript_for_input`` instead.

    Args:
        root: AST node to transform. Child nodes are rewritten in place by
            ``ast.NodeTransformer``; the (possibly replaced) root is returned.
        input_vars: Container of variable names, tested with ``in``.
        index_var: AST expression node used as the new subscript index.
            NOTE(review): the same node object is inserted at every rewritten
            site rather than copied; confirm callers never mutate it afterwards.

    Returns:
        The transformed tree, with any missing ``lineno``/``col_offset``
        filled in by ``ast.fix_missing_locations``.
    """
    import sys

    class AddInputIndicesVisitor(ast.NodeTransformer):
        def visit_Subscript(self, node):
            if get_var_name(node) in input_vars:
                # NOTE(review): the slice of an input subscript is passed
                # through unvisited; extend_subscript_for_input owns it.
                return extend_subscript_for_input(node, index_var)
            # Not an input subscript: keep descending so that input variables
            # nested in its value or slice (e.g. ``y[x]``) are still indexed.
            return self.generic_visit(node)

        def visit_Name(self, node):
            if node.id not in input_vars:
                return node
            # The subscripted value is always *read*, even when the whole
            # expression is an assignment or deletion target, so it needs a
            # fresh Load-context Name; the original context moves to the
            # Subscript. Reusing ``node`` here produced an uncompilable tree.
            value = ast.copy_location(ast.Name(id=node.id, ctx=ast.Load()), node)
            # ast.Index is deprecated (and a no-op) from Python 3.9 onward.
            if sys.version_info >= (3, 9):
                index_slice = index_var
            else:
                index_slice = ast.Index(index_var)
            subscript = ast.Subscript(value=value, slice=index_slice, ctx=node.ctx)
            return ast.copy_location(subscript, node)

    root = AddInputIndicesVisitor().visit(root)
    return ast.fix_missing_locations(root)