def __init__(self, node, attribute=None):
    """Initializes DotNode.

    The node's dot attributes start as its ``label`` plus a ``shape``.
    The shape is ``'oval'`` for a :class:`Variable` and ``'box'`` for a
    :class:`Function`. Entries in ``attribute`` are then applied on top.

    Args:
        node: :class:`Variable` object or :class:`Function` object.
        attribute (dict): Attributes for the node. Its entries override the
            default ``label`` and ``shape`` computed from ``node``.
    """
    assert isinstance(node, (variable.Variable, function.Function))
    self.node = node
    self.id_ = id(node)
    # Start from the node's label, then add the shape. The shape must be
    # merged into this dict, not assigned over it, otherwise the label is
    # thrown away.
    self.attribute = {'label': node.label}
    if isinstance(node, variable.Variable):
        self.attribute['shape'] = 'oval'
    else:
        self.attribute['shape'] = 'box'
    if attribute is not None:
        # Caller-supplied attributes take precedence over the defaults.
        self.attribute.update(attribute)
# Python Function() class: example source code
def static_backward(func):
    """Decorator to mark a function for inclusion in the backward schedule.

    The decorator is used in the same way as `static_forward`, except that
    it decorates the functions that should be added to the static
    backward-pass schedule. The wrapped function implements the
    computations of the `backward()` method of a Function instance that
    must be executed during every backward pass.

    Similarly to `static_forward`, the wrapped function should not return
    a result because it will be ignored.

    Args:
        func: A backward-pass method of a sub-class of Function that should
            be inserted into the static backward schedule when
            `static_graph` is enabled. The function must not return a value
            because any return values will be ignored.

    Returns:
        The wrapped function. It keeps the name, docstring and other
        metadata of ``func`` (via ``functools.wraps``).
    """
    import functools

    @functools.wraps(func)
    def wrapped_func(*args, **kwargs):
        # Capture this call's arguments in a zero-argument closure, so that
        # a schedule can replay exactly this call later without having to
        # supply the arguments again.
        def no_arg_func():
            func(*args, **kwargs)

        # Execute the computation now, for the current backward pass.
        no_arg_func()
        inst = args[0]
        assert isinstance(inst, function.Function)
        schedule_function = getattr(inst, 'schedule_func', None)
        # A `schedule_func` attribute on the instance means a static schedule
        # is being recorded (trace mode), so register the replayable closure.
        if schedule_function is not None:
            print('Adding function to the backward static schedule.')
            schedule_function.append_backward_function(no_arg_func)
    return wrapped_func
def __init__(self, nodes, edges, variable_style=None, function_style=None):
    """Store the pieces that make up a computational graph.

    Args:
        nodes (list): Graph nodes; every entry is either a
            :class:`Variable` object or a :class:`Function` object.
        edges (list): Graph edges; every entry is a pair of nodes.
        variable_style (dict): Dot node style applied to variables.
        function_style (dict): Dot node style applied to functions.
    """
    # The arguments are kept by reference (no copies are made).
    self.nodes, self.edges = nodes, edges
    self.variable_style, self.function_style = variable_style, function_style
def _to_dot(self):
    """Converts graph in dot format.

    Each node contributes the ``label`` property of a :class:`DotNode`
    built from it, using ``variable_style`` or ``function_style``
    according to the node's type. Each edge contributes a
    ``"<head id> -> <tail id>;"`` statement.

    Returns:
        str: The graph in dot format.

    Raises:
        TypeError: If an edge does not connect a Variable and a Function
            (in either direction).
    """
    # Collect fragments and join them once at the end, rather than
    # repeatedly concatenating onto a growing string.
    parts = ["digraph graphname{"]
    for node in self.nodes:
        assert isinstance(node, (variable.Variable, function.Function))
        if isinstance(node, variable.Variable):
            parts.append(DotNode(node, self.variable_style).label)
        else:
            parts.append(DotNode(node, self.function_style).label)
    for head, tail in self.edges:
        if (isinstance(head, variable.Variable) and
                isinstance(tail, function.Function)):
            head_attr = self.variable_style
            tail_attr = self.function_style
        elif (isinstance(head, function.Function) and
                isinstance(tail, variable.Variable)):
            head_attr = self.function_style
            tail_attr = self.variable_style
        else:
            raise TypeError(
                'head and tail should be the set of Variable and Function')
        head_node = DotNode(head, head_attr)
        tail_node = DotNode(tail, tail_attr)
        parts.append("%s -> %s;" % (head_node.id_, tail_node.id_))
    parts.append("}")
    return "".join(parts)