def forced_replace(out, x, y):
    """
    Check all internal values of the graph that compute the variable ``out``
    for occurrences of values identical with ``x``. If such occurrences are
    encountered then they are replaced with variable ``y``.

    Parameters
    ----------
    out : Theano Variable
        Root of the graph to search; ``None`` is passed straight through.
    x : Theano Variable
        Value to look for; a node matches when ``equal_computations``
        reports it computes the same thing as ``x``.
    y : Theano Variable
        Replacement for every matching node.

    Returns
    -------
    The result of ``clone(out, replace=...)`` with every matching node mapped
    to ``y``, or ``None`` when ``out`` is ``None``.

    Examples
    --------
    out := sigmoid(wu)*(1-sigmoid(wu))
    x := sigmoid(wu)
    forced_replace(out, x, y) := y*(1-y)
    """
    if out is None:
        return None

    # ``visited`` is a set of nodes that are already known and don't need to
    # be checked again, speeding up the traversal of multiply-connected graphs.
    visited = set()
    to_replace = []

    # Walk the graph with an explicit stack instead of recursion so that very
    # deep graphs cannot exhaust Python's recursion limit. Inputs are pushed
    # in reverse, and the visited check happens when a node is popped. This
    # yields the same depth-first, left-to-right order of matches that a
    # recursive traversal would produce.
    stack = [out]
    while stack:
        graph = stack.pop()
        if graph in visited:
            continue
        visited.add(graph)
        if equal_computations([graph], [x]):
            # A match is replaced as a whole, so its inputs are not searched.
            to_replace.append(graph)
        elif graph.owner:
            # Only nodes produced by an Apply node have inputs to descend into.
            stack.extend(reversed(graph.owner.inputs))

    return clone(out, replace=OrderedDict((v, y) for v in to_replace))
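# NOTE(review): the two lines that stood here ("comment list" and "table of
# contents" in Chinese) appear to be web-page navigation residue, not code.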