def count_trainable_parameters(print_model=False):
    """Count the number of trainable parameters in the current graph.

    Sums, over every variable returned by ``tf.trainable_variables()``, the
    number of elements in that variable's static shape (a scalar counts as 1).

    Args:
        print_model: if True, print each trainable variable as it is visited.

    Returns:
        int: the total number of trainable parameters.

    Raises:
        ValueError: if a trainable variable's static shape is not fully
            defined, so its element count cannot be computed.
    """
    # NOTE(review): tf.trainable_variables is the TF1 graph API (exposed as
    # tf.compat.v1.trainable_variables under TF2) -- confirm the TF version
    # this file targets.
    total_parameters = 0
    for variable in tf.trainable_variables():
        shape = variable.get_shape()
        if print_model:
            print(variable)
        # TensorShape.num_elements() multiplies the static dimensions and
        # works whether iterating the shape would yield tf.Dimension objects
        # (TF1 behavior) or plain ints (TF2 behavior); reading `dim.value`
        # directly only works for the former.
        variable_parameters = shape.num_elements()
        if variable_parameters is None:
            # num_elements() returns None when any dimension is unknown; the
            # previous manual product failed here with an opaque TypeError.
            raise ValueError(
                "Cannot count parameters of %s: shape %s is not fully defined."
                % (variable, shape)
            )
        total_parameters += variable_parameters
    return total_parameters
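# 评论列表 ("Comment list") -- web-page navigation label captured with the snippet.
# 文章目录 ("Table of contents") -- web-page navigation label captured with the snippet.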