def tf_parameter_iter(x):
    """Walk down the chain of left branches starting at ``x``.

    At every step the current node is split with ``tf_left_split``. The
    parameter counts of the right-hand pieces (from ``tf_num_params``) are
    summed and reported. The walk then continues with the left piece and
    stops once the left piece is None.

    Args:
        x: root of the subgraph (a ``tf.Tensor`` or ``tf.Operation``).

    Yields:
        Tuples ``(name, total_params, shape)``. When the current node is a
        ``tf.Tensor``, ``shape`` is its shape as a list, and the walk
        proceeds from that tensor's producing op, whose name is reported.
        For any other node, ``shape`` is ``""``.
    """
    node = x
    while True:
        if not isinstance(node, tf.Tensor):
            dims = ""
        else:
            # NOTE(review): TensorShape.as_list() raises ValueError for
            # shapes of unknown rank; confirm that case cannot reach here.
            dims = node.get_shape().as_list()
            node = node.op
        left, right = tf_left_split(node)
        n_params = sum(tf_num_params(piece) for piece in right)
        yield node.name, n_params, dims
        if left is None:
            return
        node = left
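# 评论列表 (comment list) — leftover web-page navigation label, not code
# 文章目录 (table of contents) — leftover web-page navigation label, not code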