def search(node, critereon):
    """
    Traverse the Theano graph starting at `node` and return a list of all
    nodes which match the `critereon` function.

    The walk moves from a Variable to the Apply node that produced it (its
    `owner`), and from an Apply node to each of its `inputs`. `critereon` is
    only evaluated on non-Apply (Variable) nodes. Each node is visited at most
    once, so nodes reachable along several paths are reported only once.
    Matches are returned in depth-first pre-order (inputs explored left to
    right).

    When optimizing a cost function, you can use this to get a list of all
    of the trainable params in the graph, like so:
    `lib.search(cost, lambda x: hasattr(x, "param"))`
    or
    `lib.search(cost, lambda x: hasattr(x, "param") and x.param==True)`

    :param node: graph node to start from (a Variable or an Apply node).
    :param critereon: predicate called once with each Variable node reached;
        nodes for which it returns a truthy value are collected.
    :returns: list of the matching Variable nodes.
    """
    # Use an explicit stack rather than recursion. Each op in a chain would
    # cost two Python frames (Variable -> owner Apply -> input Variable), so
    # a recursive walk hits the interpreter's recursion limit on deep graphs.
    results = []
    visited = set()
    stack = [node]
    while stack:
        current = stack.pop()
        if current in visited:
            continue
        visited.add(current)
        if isinstance(current, T.Apply):
            # Push inputs in reverse so they are popped (and fully explored)
            # left to right, reproducing the order of a recursive DFS.
            stack.extend(reversed(list(current.inputs)))
        else:  # Variable node
            if critereon(current):
                results.append(current)
            if current.owner is not None:
                stack.append(current.owner)
    return results
评论列表
文章目录