def count_trainable_param_number(variables=None):
    """Count the total number of scalar parameters in trainable variables.

    Args:
        variables: Optional iterable of variables to count. Each item must
            expose ``get_shape()`` returning an iterable of dimensions, where
            each dimension is either a ``tf.Dimension`` (read via ``.value``)
            or a plain ``int``. Defaults to ``tf.trainable_variables()``.

    Returns:
        The sum, over all variables, of the product of their dimension sizes.
        A variable with an empty shape (a scalar) contributes 1.

    Raises:
        ValueError: If any variable has a dimension whose size is unknown
            (``None``), since its parameter count cannot be determined.
    """
    if variables is None:
        # TF1-style global collection; this was the original behaviour.
        variables = tf.trainable_variables()

    total_parameters = 0
    for variable in variables:
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            # TF1 iteration yields tf.Dimension objects (size in `.value`);
            # under TF2 behaviour a TensorShape yields plain ints or None.
            size = getattr(dim, "value", dim)
            if size is None:
                raise ValueError(
                    "Cannot count parameters of variable %r: shape %s is not "
                    "fully defined." % (getattr(variable, "name", variable), shape)
                )
            variable_parameters *= size
        total_parameters += variable_parameters
    return total_parameters
# 评论列表 (page residue: "Comment list")
# 文章目录 (page residue: "Table of contents")