def tf_num_params(x, variable_op_types=("Variable", "VariableV2")):
    """Count the parameters held by the Variables feeding a TensorFlow subgraph.

    Walks the graph backwards from ``x`` through op inputs and sums the
    number of elements of every variable op reached. Each op is visited at
    most once, so a Variable shared by several branches is counted once.

    Args:
        x: root of the subgraph, a ``tf.Tensor`` or a ``tf.Operation``.
        variable_op_types: op type names treated as variables. The default
            covers the legacy ``"Variable"`` op and its successor
            ``"VariableV2"``.

    Returns:
        Total number of elements found in all distinct Variables in the
        subgraph (0 if none are reached).

    Raises:
        TypeError: if a reached Variable's static shape is not fully defined
            (``TensorShape.num_elements()`` returns ``None`` in that case).
    """
    if isinstance(x, tf.Tensor):
        x = x.op
    variable_op_types = frozenset(variable_op_types)

    # Iterative DFS with a visited set: avoids recursion-depth limits on deep
    # graphs and prevents re-counting variables reachable along several paths.
    seen = set()
    stack = [x]
    total = 0
    while stack:
        op = stack.pop()
        if op in seen:
            continue
        seen.add(op)
        if op.type in variable_op_types:
            # A variable op's first output carries the variable's static
            # shape; its own inputs (if any) are not walked.
            total += op.outputs[0].get_shape().num_elements()
        else:
            stack.extend(t.op for t in op.inputs)
    return total