def show_trainable_parameters(verbose=False):
    """Print the number of trainable parameters in the current graph.

    Trainable variables are grouped by the first component of their name
    (everything before the first ``/``). A table with one row per group is
    printed, followed by the grand total.

    Parameters
    ----------
    verbose: Boolean, optional
        Show additional information and list the number of trainable
        parameters per variable, not just the per-group and total sums.

    Raises
    ------
    ValueError
        If a trainable variable's shape is not fully defined, so its
        parameter count cannot be determined.
    """
    total_width = 80
    # Name column + " | " separator (3 chars) + 8-char count column
    # add up to exactly total_width.
    row_format = "{:%d} | {:8d}" % (total_width - 11)

    trainable_vars = tf.trainable_variables()
    if not trainable_vars:
        print("No model-params found.")
        return

    if verbose:
        print("-" * total_width)

    total_parameters = 0
    groups = {}
    for var in trainable_vars:
        # num_elements() handles both TF1 (Dimension) and TF2 (int) shapes,
        # unlike iterating and reading `dim.value`, which is TF1-only.
        var_params = var.get_shape().num_elements()
        if var_params is None:
            raise ValueError(
                "Cannot count parameters of variable '%s': shape %s is not "
                "fully defined." % (var.name, var.get_shape()))
        if verbose:
            print(row_format.format(var.name, var_params))
        total_parameters += var_params
        # Group by the name prefix before the first '/', e.g.
        # "encoder/dense/kernel:0" -> "encoder".
        group_name = var.name.split('/')[0]
        groups[group_name] = groups.get(group_name, 0) + var_params

    print("-" * total_width)
    # dict.iteritems() does not exist in Python 3; items() works in 2 and 3.
    for group, count in groups.items():
        print(row_format.format(group, count))
    print("=" * total_width)
    print(row_format.format("TOTAL", total_parameters))
    print("-" * total_width)
评论列表
文章目录