def __init__(self, returns, *args, **kwargs):
    """Set up a computation from the ops it returns and the ops it takes.

    Arguments:
        returns: ``None``, a single ``Op``, or a container of values that
            ``as_op`` accepts. A container is rebuilt with its own type
            from the converted elements; a single ``Op`` is wrapped in a
            one-element list; ``None`` yields an empty list.
        *args: The ops to register as parameters of the computation, or
            the single value ``'all'`` to use everything reported by
            ``self.placeholders()``.
        **kwargs: Forwarded unchanged to the parent initializer.

    Raises:
        ValueError: If ``returns`` is not ``None``, an ``Op``, or a
            container; if a placeholder reported by
            ``self.placeholders()`` is not covered by ``args``; or if an
            argument's tensor does not have ``is_input`` set.
    """
    # The ABC aliases on the top-level ``collections`` module were removed
    # in Python 3.10; fall back to the old location only where
    # ``collections.abc`` does not exist (Python 2).
    try:
        from collections.abc import Container
    except ImportError:
        from collections import Container

    # The container check deliberately comes first, as in the original
    # ordering of the branches.
    if isinstance(returns, Container):
        # Preserve the caller's container type (list, tuple, ...).
        values = type(returns)(as_op(ret) for ret in returns)
    elif isinstance(returns, Op):
        values = [as_op(returns)]
    elif returns is not None:
        raise ValueError(
            "returns must be None, an Op, or a container of ops, "
            "but got {!r}.".format(returns)
        )
    else:
        values = []

    self.values = values
    self.returns = returns
    # The parent initializer takes the converted ops under the keyword
    # ``all``; only the local variable was renamed to stop shadowing the
    # builtin.
    super(ComputationOp, self).__init__(all=values, **kwargs)

    placeholders = self.placeholders()
    if len(args) == 1 and args[0] == 'all':
        args = placeholders

    args = tuple(as_op(arg) for arg in args)
    arg_tensors = set(arg.tensor for arg in args)
    # Every placeholder must be matched by the tensor of some argument.
    missing_tensors = list(placeholders - arg_tensors)
    if missing_tensors:
        raise ValueError(("All used placeholders must be supplied to a "
                          "computation. Currently missed {}."
                          ).format(missing_tensors))

    # Validate every argument before mutating any state below.
    for arg in args:
        if not arg.tensor.is_input:
            raise ValueError((
                'The arguments to a computation must all be Ops with property '
                'is_input=True, but the op passed had is_input=False. '
                'In most cases you want to pass placeholder ops in as arguments. '
                '{op} was passed in, of type {op_type}.'
            ).format(
                op=arg,
                op_type=arg.__class__.__name__,
            ))

    self.parameters = args
    for arg in args:
        self.add_control_dep(arg)
评论列表
文章目录