op_graph.py 文件源码

python
阅读 40 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def __init__(self, returns, *args, **kwargs):
        """
        Set up a computation from the ops it returns and the ops it takes as parameters.

        Arguments:
            returns: What the computation returns. Either a container whose
                elements are each passed through ``as_op`` (the container is
                rebuilt with the same type), a single ``Op``, or ``None`` for
                a computation that returns nothing.
            *args: The ops to use as the computation's parameters. Passing the
                single string ``'all'`` uses everything returned by
                ``self.placeholders()`` instead.
            **kwargs: Forwarded unchanged to the superclass constructor.

        Raises:
            ValueError: If ``returns`` is not a container, an ``Op`` or
                ``None``; if a placeholder reported by ``self.placeholders()``
                is not covered by the supplied arguments; or if an argument's
                tensor does not have ``is_input`` set.
        """
        # ``collections.Container`` was removed in Python 3.10; the ABC lives
        # in ``collections.abc``. Fall back to the old location for Python 2.
        try:
            from collections.abc import Container
        except ImportError:
            from collections import Container

        if isinstance(returns, Container):
            # Preserve the caller's container type (list, tuple, set, ...).
            values = type(returns)(as_op(ret) for ret in returns)
        elif isinstance(returns, Op):
            values = [as_op(returns)]
        elif returns is not None:
            raise ValueError(
                'returns must be an Op, a container of ops, or None, but got '
                '{returns_type}.'.format(returns_type=type(returns).__name__)
            )
        else:
            values = []

        self.values = values
        self.returns = returns
        # The superclass keyword is named ``all``; only the local was renamed
        # to avoid shadowing the builtin.
        super(ComputationOp, self).__init__(all=values, **kwargs)

        # NOTE(review): placeholders is used with set difference below, so it
        # is presumably a set of tensors -- confirm against placeholders().
        placeholders = self.placeholders()
        if len(args) == 1 and args[0] == 'all':
            args = placeholders

        args = tuple(as_op(arg) for arg in args)
        arg_tensors = set(arg.tensor for arg in args)
        missing_tensors = list(placeholders - arg_tensors)

        if missing_tensors:
            raise ValueError(("All used placeholders must be supplied to a "
                              "computation. Currently missed {}."
                              ).format(missing_tensors))

        # Validate every argument before mutating any state, so a bad argument
        # leaves no partially registered parameters or control dependencies.
        for arg in args:
            if not arg.tensor.is_input:
                raise ValueError((
                    'The arguments to a computation must all be Ops with property '
                    'is_input=True, but the op passed had is_input=False.  '
                    'In most cases you want to pass placeholder ops in as arguments.  '
                    '{op} was passed in, of type {op_type}.'
                ).format(
                    op=arg,
                    op_type=arg.__class__.__name__,
                ))

        self.parameters = args
        for arg in args:
            self.add_control_dep(arg)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号