def tf_print(x, depth=0, finished=None, printer=print):
    """Print a TensorFlow graph as an indented tree by walking op inputs.

    Starting from ``x``, emits one line per op in the form
    ``"<indent><name> <type> <shape>"`` and then recurses into each of the
    op's input tensors at ``depth + 1``. An op that was already printed is
    emitted again as ``"<indent><<name>> <type> <shape>"`` and is not
    expanded a second time.

    Args:
      x: a tf.Tensor or tf.Operation. A Tensor is replaced by the op that
        produces it, and the tensor's static shape is printed (or
        ``"<unknown>"`` when its rank is not known). If the resulting op is
        an ``Identity``, the op feeding it is printed instead (one level
        only).
      depth: current printing depth; each level adds one leading space.
      finished: set of ops already printed. It is shared across the
        recursion and updated in place; a new set is created when ``None``.
      printer: callable taking a single string, used to emit each line.

    Returns:
      None. The function is used only for its printing side effect.
    """
    if finished is None:
        finished = set()
    if isinstance(x, tf.Tensor):
        static_shape = x.get_shape()
        # TensorShape.as_list() raises ValueError when the rank is unknown;
        # print a placeholder instead of aborting the whole traversal.
        if static_shape.ndims is not None:
            shape = static_shape.as_list()
        else:
            shape = "<unknown>"
        x = x.op
    else:
        shape = ""
    # Skip a single Identity op and show the op that feeds it.
    if x.type == "Identity":
        x = x.inputs[0].op
    indent = " " * depth
    if x in finished:
        # Already printed elsewhere in the tree: mark with <...> and stop.
        printer("%s<%s> %s %s" % (indent, x.name, x.type, shape))
        return
    finished.add(x)
    printer("%s%s %s %s" % (indent, x.name, x.type, shape))
    # NOTE(review): _truncate_structure is defined elsewhere; presumably it
    # decides when not to descend into an op's inputs -- confirm.
    if not _truncate_structure(x):
        for y in x.inputs:
            tf_print(y, depth + 1, finished, printer=printer)
评论列表
文章目录