compile_tensorflow.py 文件源码

python
阅读 53 收藏 0 点赞 0 评论 0

项目:TerpreT 作者: 51alg 项目源码 文件源码
def split_params_from_model(self, root):
    """Separate parameter declarations from the rest of a model AST.

    Walks ``root`` twice:

    1. A collector gathers every ``Assign`` node that
       ``u.is_param_declaration`` accepts, plus the ``Assign`` node that
       ``u.is_self_params_assignment`` accepts.
    2. A transformer derived from ``self.MyTransformer`` strips exactly those
       statements out of the tree.

    Both passes skip the bodies of function definitions. This means every
    statement returned here is one that was actually removed from the tree.

    Args:
        root: The AST to split, normally an ``ast.Module``.

    Returns:
        A tuple ``(root, param_statements, param_dict)``:

        - the tree returned by the transformer;
        - the list of collected param-declaration ``Assign`` nodes, in
          visiting order;
        - the last ``Assign`` node accepted by ``u.is_self_params_assignment``,
          or ``None`` if there was none.
    """
    class CollectParamsVisitor(ast.NodeVisitor):
        def __init__(self_c):
            # Initialise the results here as well as in visit_Module.
            # Otherwise a root that is not an ast.Module would leave the
            # attributes undefined, and the return statement below would
            # raise AttributeError.
            self_c.param_statements = []
            self_c.param_dict = None

        def visit_Module(self_c, node):
            # Reset so a reused visitor starts from a clean slate.
            self_c.param_statements = []
            self_c.param_dict = None
            self_c.generic_visit(node)

        def visit_FunctionDef(self_c, node):
            """Don't recurse into user defined functions.

            This mirrors RemoveParamsTransformer.visit_FunctionDef. Without
            it, param statements inside a function body would be collected
            here but left in place by the transformer, so they would end up
            declared twice.
            """
            return

        def visit_Assign(self_c, node):
            if u.is_param_declaration(node):
                self_c.param_statements.append(node)
            elif u.is_self_params_assignment(node):
                # A later match overwrites an earlier one; only the last
                # one is returned.
                self_c.param_dict = node

    cpv = CollectParamsVisitor()
    cpv.visit(root)

    class RemoveParamsTransformer(self.MyTransformer):
        def visit_FunctionDef(self_r, node):
            """ Don't recurse into user defined functions """
            return node

        def visit_Assign(self_r, node):
            # Returning an empty list drops the statement from the
            # enclosing statement list.
            if u.is_param_declaration(node) or u.is_self_params_assignment(node):
                return []
            return node

    rpt = RemoveParamsTransformer()
    root = rpt.visit(root)

    return root, cpv.param_statements, cpv.param_dict
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号