def __init__(self, inputs, outputs, **kwargs):
    """Validate the graph endpoints and build a private copy of the graph.

    Every ``SharedVariable`` found among the graph inputs of ``outputs`` is
    replaced, in the copy, by a new variable obtained by calling its type
    (``var.type()``). The replacements are appended after the explicit
    ``inputs``, so the inner graph does not reference the shared variables
    directly.

    Parameters
    ----------
    inputs : list of gof.Variable
        Explicit inputs of the graph.
    outputs : list of gof.Variable
        Outputs of the graph.
    **kwargs
        Stored unchanged on ``self.kwargs``. The keys ``'updates'`` and
        ``'givens'`` are rejected.

    Raises
    ------
    TypeError
        If ``outputs`` or ``inputs`` is not a list, if any of their elements
        is not a ``gof.Variable``, or if ``kwargs`` contains ``'updates'`` or
        ``'givens'``.

    Attributes set
    --------------
    shared_inputs, new_inputs, new_outputs, inputs, outputs, kwargs,
    input_types, output_types.
    """
    if not isinstance(outputs, list):
        raise TypeError('outputs must be list', outputs)
    # Checked explicitly (mirroring ``outputs``) so that a tuple or other
    # non-list produces a clear message instead of failing inside the
    # ``inputs + outputs`` concatenation below.
    if not isinstance(inputs, list):
        raise TypeError('inputs must be list', inputs)
    for variable in inputs + outputs:
        if not isinstance(variable, gof.Variable):
            raise TypeError(
                'inputs and outputs must be Variable instances', variable)
    if 'updates' in kwargs or 'givens' in kwargs:
        raise TypeError('updates and givens are not allowed in kwargs')

    # To support shared variables correctly, the inner function must not
    # see them; otherwise there is a problem with the gradient. Each one is
    # therefore swapped for a new variable of the same type, passed as an
    # extra input after the explicit ones.
    self.shared_inputs = [var for var in gof.graph.inputs(outputs)
                          if isinstance(var, SharedVariable)]
    shared_vars = [var.type() for var in self.shared_inputs]
    new = rebuild_collect_shared(outputs, inputs=inputs + shared_vars,
                                 replace=dict(izip(self.shared_inputs,
                                                   shared_vars)),
                                 copy_inputs_over=False)
    # The list-pattern unpacking also enforces that exactly four
    # bookkeeping items come back; the clone map is not used here.
    (new_inputs, new_outputs,
     [_clone_d, update_d, update_expr, shared_inputs]) = new
    # Internal sanity checks: arity is preserved and the rebuild reported
    # no updates, update expressions, or additional shared inputs.
    assert len(new_inputs) == len(inputs) + len(self.shared_inputs)
    assert len(new_outputs) == len(outputs)
    assert not update_d
    assert not update_expr
    assert not shared_inputs

    self.new_inputs = new_inputs
    self.new_outputs = new_outputs
    self.inputs = inputs
    self.outputs = outputs
    self.kwargs = kwargs
    self.input_types = [inp.type for inp in inputs]
    self.output_types = [out.type for out in outputs]
评论列表
文章目录